1
1
# ' SHAP Interaction Plot
2
2
# '
3
- # ' Plots a beeswarm plot for each feature pair. Diagonals represent the main effects,
4
- # ' while off-diagonals show interactions (multiplied by two due to symmetry).
5
- # ' The colors on the beeswarm plots represent min-max scaled feature values.
3
+ # ' @description
4
+ # ' Creates a beeswarm plot or a barplot of SHAP interaction values/main effects.
5
+ # '
6
+ # ' In the beeswarm plot (`kind = "beeswarm"`), diagonals represent the main effects,
7
+ # ' while off-diagonals show SHAP interactions (multiplied by two due to symmetry).
8
+ # ' The color axis represents min-max scaled feature values.
6
9
# ' Non-numeric features are transformed to numeric by calling [data.matrix()] first.
7
10
# ' The features are sorted in decreasing order of usual SHAP importance.
8
11
# '
12
+ # ' The barplot (`kind = "bar"`) shows average absolute SHAP interaction values
13
+ # ' and main effects for each feature pair.
14
+ # ' Again, due to symmetry, the interaction values are multiplied by two.
15
+ # '
9
16
# ' @param object An object of class "(m)shapviz" containing element `S_inter`.
10
17
# ' @param kind Set to "no" to return the matrix of average absolute SHAP
11
18
# ' interactions (or a list of such matrices in case of object of class "mshapviz").
19
26
# ' absolute SHAP values (or a list of such matrices in case of "mshapviz" object).
20
27
# ' @examples
21
28
# ' dtrain <- xgboost::xgb.DMatrix(
22
- # ' data.matrix(iris[, -1]), label = iris[, 1], nthread = 1
29
+ # ' data.matrix(iris[, -1]),
30
+ # ' label = iris[, 1], nthread = 1
23
31
# ' )
24
32
# ' fit <- xgboost::xgb.train(data = dtrain, nrounds = 10, nthread = 1)
25
33
# ' x <- shapviz(fit, X_pred = dtrain, X = iris, interactions = TRUE)
26
34
# ' sv_interaction(x, kind = "no")
27
35
# ' sv_interaction(x, max_display = 2, size = 3)
36
+ # ' sv_interaction(x, kind = "bar")
28
37
# ' @seealso [sv_importance()]
29
38
# ' @export
30
39
sv_interaction <- function (object , ... ) {
@@ -41,47 +50,78 @@ sv_interaction.default <- function(object, ...) {
41
50
# ' @describeIn sv_interaction
42
51
# ' SHAP interaction plot for an object of class "shapviz".
43
52
# ' @export
44
- sv_interaction.shapviz <- function (object , kind = c(" beeswarm" , " no" ),
45
- max_display = 7L , alpha = 0.3 ,
46
- bee_width = 0.3 , bee_adjust = 0.5 ,
47
- viridis_args = getOption(" shapviz.viridis_args" ),
48
- color_bar_title = " Row feature value" ,
49
- sort_features = TRUE , ... ) {
53
+ sv_interaction.shapviz <- function (
54
+ object ,
55
+ kind = c(" beeswarm" , " bar" , " no" ),
56
+ max_display = 15L - 8 * (kind == " beeswarm" ),
57
+ alpha = 0.3 ,
58
+ bee_width = 0.3 ,
59
+ bee_adjust = 0.5 ,
60
+ viridis_args = getOption(" shapviz.viridis_args" ),
61
+ color_bar_title = " Row feature value" ,
62
+ sort_features = TRUE ,
63
+ fill = " #fca50a" ,
64
+ bar_width = 2 / 3 ,
65
+ ... ) {
50
66
kind <- match.arg(kind )
51
67
if (is.null(get_shap_interactions(object ))) {
52
68
stop(" No SHAP interaction values available." )
53
69
}
70
+
71
+ # Sort features by SHAP importance first (kind = "bar" re-sorts the pairs by value below)
54
72
ord <- names(.get_imp(get_shap_values(object ), sort_features = sort_features ))
55
73
object <- object [, ord ]
56
74
75
+ # Calculate average absolute SHAP interactions
76
+ M <- apply(abs(get_shap_interactions(object )), MARGIN = 2 : 3 , FUN = mean )
77
+ M <- M + t(M ) - diag(diag(M )) # Off-diagonals twice
78
+
57
79
if (kind == " no" ) {
58
- mat <- apply(abs(get_shap_interactions(object )), 2 : 3 , mean )
59
- off_diag <- row(mat ) != col(mat )
60
- mat [off_diag ] <- mat [off_diag ] * 2 # compensate symmetry
61
- return (mat )
80
+ return (M )
81
+ }
82
+
83
+ if (kind == " bar" ) {
84
+ # Turn to long format and make feature pair names
85
+ imp_df <- transform(
86
+ as.data.frame.table(M , responseName = " value" ),
87
+ feature = ifelse(Var1 == Var2 , as.character(Var1 ), paste(Var1 , Var2 , sep = " :" ))
88
+ )
89
+ if (sort_features ) {
90
+ imp_df <- imp_df [order(imp_df $ value , decreasing = TRUE ), ]
91
+ imp_df <- transform(imp_df , feature = factor (feature , levels = rev(feature )))
92
+ }
93
+ if (nrow(imp_df ) > max_display ) {
94
+ imp_df <- imp_df [seq_len(max_display ), ]
95
+ }
96
+
97
+ p <- ggplot2 :: ggplot(imp_df , ggplot2 :: aes(x = value , y = feature )) +
98
+ ggplot2 :: geom_bar(fill = fill , width = bar_width , stat = " identity" , ... ) +
99
+ ggplot2 :: labs(x = " mean(|SHAP interaction value|)" , y = ggplot2 :: element_blank())
100
+
101
+ return (p )
62
102
}
63
103
104
+ # kind == "beeswarm"
64
105
if (ncol(object ) > max_display ) {
65
106
ord <- ord [seq_len(max_display )]
66
107
object <- object [, ord ]
67
108
}
68
109
69
- # Prepare data.frame for beeswarm
70
110
S_inter <- get_shap_interactions(object )
71
111
X <- .scale_X(get_feature_values(object ))
72
112
X_long <- as.data.frame.table(X )
73
113
df <- transform(
74
114
as.data.frame.table(S_inter , responseName = " value" ),
75
115
Variable1 = factor (Var2 , levels = ord ),
76
116
Variable2 = factor (Var3 , levels = ord ),
77
- color = X_long $ Freq # Correctly recycled along the third dimension of S_inter
117
+ color = X_long $ Freq # Correctly recycled along the third dimension of S_inter
78
118
)
79
119
80
120
# Compensate symmetry
81
121
mask <- df [[" Variable1" ]] != df [[" Variable2" ]]
82
122
df [mask , " value" ] <- 2 * df [mask , " value" ]
83
123
84
- ggplot2 :: ggplot(df , ggplot2 :: aes(x = value , y = " 1" )) +
124
+ p <- ggplot2 :: ggplot(df , ggplot2 :: aes(x = value , y = " 1" )) +
85
125
ggplot2 :: geom_vline(xintercept = 0 , color = " darkgray" ) +
86
126
ggplot2 :: geom_point(
87
127
ggplot2 :: aes(color = color ),
@@ -104,17 +144,25 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
104
144
axis.ticks.y = ggplot2 :: element_blank(),
105
145
axis.text.y = ggplot2 :: element_blank()
106
146
)
147
+ return (p )
107
148
}
108
149
109
150
# ' @describeIn sv_interaction
110
151
# ' SHAP interaction plot for an object of class "mshapviz".
111
152
# ' @export
112
- sv_interaction.mshapviz <- function (object , kind = c(" beeswarm" , " no" ),
113
- max_display = 7L , alpha = 0.3 ,
114
- bee_width = 0.3 , bee_adjust = 0.5 ,
115
- viridis_args = getOption(" shapviz.viridis_args" ),
116
- color_bar_title = " Row feature value" ,
117
- sort_features = TRUE , ... ) {
153
+ sv_interaction.mshapviz <- function (
154
+ object ,
155
+ kind = c(" beeswarm" , " bar" , " no" ),
156
+ max_display = 7L ,
157
+ alpha = 0.3 ,
158
+ bee_width = 0.3 ,
159
+ bee_adjust = 0.5 ,
160
+ viridis_args = getOption(" shapviz.viridis_args" ),
161
+ color_bar_title = " Row feature value" ,
162
+ sort_features = TRUE ,
163
+ fill = " #fca50a" ,
164
+ bar_width = 2 / 3 ,
165
+ ... ) {
118
166
kind <- match.arg(kind )
119
167
120
168
plot_list <- lapply(
@@ -129,11 +177,15 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
129
177
viridis_args = viridis_args ,
130
178
color_bar_title = color_bar_title ,
131
179
sort_features = sort_features ,
180
+ fill = fill ,
181
+ bar_width = bar_width ,
132
182
...
133
183
)
134
184
if (kind == " no" ) {
135
185
return (plot_list )
136
186
}
137
- plot_list <- add_titles(plot_list , nms = names(object )) # see sv_waterfall()
138
- patchwork :: wrap_plots(plot_list )
187
+ plot_list <- add_titles(plot_list , nms = names(object )) # see sv_waterfall()
188
+ p <- patchwork :: wrap_plots(plot_list , axis_titles = " collect" , guides = " collect" )
189
+
190
+ return (p )
139
191
}
0 commit comments