Skip to content

Commit d94381c

Browse files
authored
Merge pull request #137 from ModelOriented/imp-sort
sv_importance and sv_interaction receive a sort_features option.
2 parents 2713bda + 163ba46 commit d94381c

File tree

10 files changed

+73
-23
lines changed

10 files changed

+73
-23
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: shapviz
22
Title: SHAP Visualizations
3-
Version: 0.9.3
3+
Version: 0.9.4
44
Authors@R: c(
55
person("Michael", "Mayer", , "[email protected]", role = c("aut", "cre")),
66
person("Adrian", "Stando", , "[email protected]", role = "ctb")
@@ -21,7 +21,7 @@ Depends:
2121
R (>= 3.6.0)
2222
Encoding: UTF-8
2323
Roxygen: list(markdown = TRUE)
24-
RoxygenNote: 7.2.3
24+
RoxygenNote: 7.3.1
2525
Imports:
2626
ggfittext (>= 0.8.0),
2727
gggenes,

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# shapviz 0.9.4
2+
3+
## Improvements
4+
5+
- New argument `sort_features = TRUE` in `sv_importance()` and `sv_interaction()`. Set to `FALSE` to show the features as they appear in your SHAP matrix. In that case, the plots will show the *first* `max_display` features, not the *most important* features. Implements #136.
6+
17
# shapviz 0.9.3
28

39
## `sv_dependence()`: Control over automatic color feature selection

R/sv_importance.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#' @param kind Should a "bar" plot (the default), a "beeswarm" plot, or "both" be shown?
1515
#' Set to "no" in order to suppress plotting. In that case, the sorted
1616
#' SHAP feature importances of all variables are returned.
17-
#' @param max_display Maximum number of features (with highest importance) to plot.
17+
#' @param max_display How many features should be plotted?
1818
#' Set to `Inf` to show all features. Has no effect if `kind = "no"`.
1919
#' @param fill Color used to fill the bars (only used if bars are shown).
2020
#' @param bar_width Relative width of the bars (only used if bars are shown).
@@ -38,6 +38,7 @@
3838
#' (only if `show_numbers = TRUE`). To change to scientific notation, use
3939
#' `function(x) prettyNum(x, scientific = TRUE)`.
4040
#' @param number_size Text size of the numbers (if `show_numbers = TRUE`).
41+
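#' @param sort_features Should features be sorted by decreasing importance? The default is `TRUE`. If `FALSE`, features keep the order of the SHAP matrix, and the *first* `max_display` features are shown.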
4142
#' @param ... Arguments passed to [ggplot2::geom_bar()] (if `kind = "bar"`) or to
4243
#' [ggplot2::geom_point()] otherwise. For instance, passing `alpha = 0.2` will produce
4344
#' semi-transparent beeswarms, and setting `size = 3` will produce larger dots.
@@ -75,10 +76,10 @@ sv_importance.shapviz <- function(object, kind = c("bar", "beeswarm", "both", "n
7576
viridis_args = getOption("shapviz.viridis_args"),
7677
color_bar_title = "Feature value",
7778
show_numbers = FALSE, format_fun = format_max,
78-
number_size = 3.2, ...) {
79+
number_size = 3.2, sort_features = TRUE, ...) {
7980
stopifnot("format_fun must be a function" = is.function(format_fun))
8081
kind <- match.arg(kind)
81-
imp <- .get_imp(get_shap_values(object))
82+
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)
8283

8384
if (kind == "no") {
8485
return(imp)
@@ -162,13 +163,13 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
162163
viridis_args = getOption("shapviz.viridis_args"),
163164
color_bar_title = "Feature value",
164165
show_numbers = FALSE, format_fun = format_max,
165-
number_size = 3.2, ...) {
166+
number_size = 3.2, sort_features = TRUE, ...) {
166167
kind <- match.arg(kind)
167168
bar_type <- match.arg(bar_type)
168169

169170
# All other cases are done via {patchwork}
170171
if (kind %in% c("bar", "no") && bar_type != "separate") {
171-
imp <- .get_imp(get_shap_values(object))
172+
imp <- .get_imp(get_shap_values(object), sort_features = sort_features)
172173
if (kind == "no") {
173174
return(imp)
174175
}
@@ -223,6 +224,7 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
223224
show_numbers = show_numbers,
224225
format_fun = format_fun,
225226
number_size = number_size,
227+
sort_features = sort_features,
226228
...
227229
)
228230
if (kind == "no") {
@@ -243,13 +245,20 @@ sv_importance.mshapviz <- function(object, kind = c("bar", "beeswarm", "both", "
243245
(z - r[1L]) /(r[2L] - r[1L])
244246
}
245247

246-
.get_imp <- function(z) {
248+
.get_imp <- function(z, sort_features = TRUE) {
247249
if (is.matrix(z)) {
248-
return(sort(colMeans(abs(z)), decreasing = TRUE))
250+
imp <- colMeans(abs(z))
251+
if (sort_features) {
252+
imp <- sort(imp, decreasing = TRUE)
253+
}
254+
return(imp)
249255
}
250256
# list/mshapviz
251257
imp <- sapply(z, function(x) colMeans(abs(x)))
252-
imp[order(-rowSums(imp)), ]
258+
if (sort_features) {
259+
imp <- imp[order(-rowSums(imp)), ]
260+
}
261+
return(imp)
253262
}
254263

255264
.scale_X <- function(X) {

R/sv_interaction.R

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,13 @@ sv_interaction.shapviz <- function(object, kind = c("beeswarm", "no"),
4545
max_display = 7L, alpha = 0.3,
4646
bee_width = 0.3, bee_adjust = 0.5,
4747
viridis_args = getOption("shapviz.viridis_args"),
48-
color_bar_title = "Row feature value", ...) {
48+
color_bar_title = "Row feature value",
49+
sort_features = TRUE, ...) {
4950
kind <- match.arg(kind)
5051
if (is.null(get_shap_interactions(object))) {
5152
stop("No SHAP interaction values available.")
5253
}
53-
ord <- names(.get_imp(get_shap_values(object)))
54+
ord <- names(.get_imp(get_shap_values(object), sort_features = sort_features))
5455
object <- object[, ord]
5556

5657
if (kind == "no") {
@@ -112,7 +113,8 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
112113
max_display = 7L, alpha = 0.3,
113114
bee_width = 0.3, bee_adjust = 0.5,
114115
viridis_args = getOption("shapviz.viridis_args"),
115-
color_bar_title = "Row feature value", ...) {
116+
color_bar_title = "Row feature value",
117+
sort_features = TRUE, ...) {
116118
kind <- match.arg(kind)
117119

118120
plot_list <- lapply(
@@ -126,6 +128,7 @@ sv_interaction.mshapviz <- function(object, kind = c("beeswarm", "no"),
126128
bee_adjust = bee_adjust,
127129
viridis_args = viridis_args,
128130
color_bar_title = color_bar_title,
131+
sort_features = sort_features,
129132
...
130133
)
131134
if (kind == "no") {

man/shapviz-package.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sv_importance.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/sv_interaction.Rd

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packaging.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ library(usethis)
1515
use_description(
1616
fields = list(
1717
Title = "SHAP Visualizations",
18-
Version = "0.9.3",
18+
Version = "0.9.4",
1919
Description = "Visualizations for SHAP (SHapley Additive exPlanations),
2020
such as waterfall plots, force plots, various types of importance plots,
2121
dependence plots, and interaction plots.

tests/testthat/test-plots-mshapviz.R

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,31 @@ test_that("plots work for non-syntactic column names", {
108108
)
109109
})
110110

111-
test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
112-
X_pred <- data.matrix(iris[, -1L])
113-
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
114-
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
115-
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
116-
x <- c(m1 = x, m2 = x)
111+
X_pred <- data.matrix(iris[, -1L])
112+
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
113+
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
114+
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
115+
x <- c(m1 = x, m2 = x)
117116

117+
test_that("sv_importance() and sv_interaction() and kind = 'no' gives matrix", {
118118
imp <- sv_importance(x, kind = "no")
119119
expect_true(is.matrix(imp) && all(dim(imp) == c(4L, length(x))))
120120

121121
inter <- sv_interaction(x, kind = "no")
122122
expect_true(is.list(inter) && all(dim(inter[[1L]]) == rep(ncol(X_pred), 2L)))
123123
})
124124

125+
126+
test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
127+
imp <- sv_importance(x, kind = "no", sort_features = FALSE)
128+
expect_true(all(rownames(imp) == colnames(x$m1)))
129+
130+
inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
131+
expect_true(all(rownames(inter$m1) == colnames(x$m1)))
132+
})
133+
134+
135+
125136
test_that("sv_dependence() does not work with multiple v", {
126137
X_pred <- data.matrix(iris[, -1L])
127138
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)

tests/testthat/test-plots-shapviz.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,16 @@ test_that("sv_importance() and sv_interaction() and kind = 'no' gives numeric ou
173173
expect_true(is.numeric(inter) && all(dim(inter) == rep(ncol(X_pred), 2L)))
174174
})
175175

176+
test_that("sv_importance() and sv_interaction() respect sort_features = FALSE", {
177+
X_pred <- data.matrix(iris[, -1L])
178+
dtrain <- xgboost::xgb.DMatrix(X_pred, label = iris[, 1L], nthread = 1)
179+
fit <- xgboost::xgb.train(params = list(nthread = 1L), data = dtrain, nrounds = 1L)
180+
x <- shapviz(fit, X_pred = X_pred, interactions = TRUE)
181+
182+
imp <- sv_importance(x, kind = "no", sort_features = FALSE)
183+
expect_true(all(names(imp) == colnames(x)))
184+
185+
inter <- sv_interaction(x, kind = "no", sort_features = FALSE)
186+
expect_true(all(rownames(inter) == colnames(x)))
187+
})
188+

0 commit comments

Comments
 (0)