Description
Thanks for providing a great SHAP visualisation package for R!
I'm looking into fast ways to surface interaction effects in H2O GBMs. Unfortunately, unlike xgboost, H2O does not provide interaction SHAP values and hence shapviz
relies on a heuristic based on weighted squared Pearson correlation between the SHAP value and other features' values in its potential_interactions()
implementation. I think that's a reasonable approach, but it doesn't work well for unordered categorical features (where it converts them to their arbitrarily ordered factor level numbers using data.matrix()
).
A natural extension of what you are doing now, which I believe would be more appropriate for categorical features, would be to consider the R squared of a linear regression model of the SHAP values on each of the other features. For continuous features, that would give you the exact same value you have now, because the R squared of a simple linear regression equals the squared Pearson correlation. For categorical features, it would measure the association between the unordered factor levels and the SHAP values in a way that is not constrained by the arbitrary numbering of the factor levels.
If you want to implement that, lines 230-233 would have to be replaced by:
# Complicated case: we need to rely on R squared based heuristic
#
# For every column of `x`, fit the linear model `s ~ column` and return its
# R squared.
#
# s: vector of SHAP values (the regression response).
# x: data.frame (or list) of feature columns; each column is used on its own
#    as the single predictor.
#
# Returns a numeric vector with one element per column of `x`, carrying the
# column names. Columns for which the model cannot be fitted yield NA.
# An `x` without columns yields a zero-length numeric vector.
r_sq <- function(s, x) {
  vapply(
    x,
    function(x_col) {
      tryCatch(
        summary(stats::lm(s ~ x_col))$r.squared,
        # Best effort: lm() can fail, e.g. when a factor has fewer than two
        # observed levels. Such columns are reported as missing.
        error = function(e) NA_real_
      )
    },
    FUN.VALUE = numeric(1)
  )
}
Here's a full example using a public H2O data set:
library(shapviz)
library(h2o)

# Start (or connect to) a local H2O cluster
h2o.init()

# Load the public prostate data set into H2O
prostate <- h2o.importFile(
  "http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv"
)

# Treat the response and the categorical predictors as factors
prostate[, "CAPSULE"] <- as.factor(prostate[, "CAPSULE"])
prostate[, "RACE"] <- as.factor(prostate[, "RACE"])
prostate[, "DPROS"] <- as.factor(prostate[, "DPROS"])
prostate[, "DCAPS"] <- as.factor(prostate[, "DCAPS"])
prostate[, "GLEASON"] <- as.factor(prostate[, "GLEASON"])

# Model specification: response and predictors
response <- "CAPSULE"
predictors <- c("AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON")

# Fit a GBM with 5-fold cross-validation and a fixed seed
pros_gbm <- h2o.gbm(
  training_frame = prostate,
  x = predictors,
  y = response,
  nfolds = 5,
  keep_cross_validation_predictions = TRUE,
  seed = 1111
)

# Build the shapviz object; the feature values are passed as a plain data.frame
shp <- shapviz(
  pros_gbm,
  X_pred = prostate,
  X = as.data.frame(prostate)
)
# Replace correlation measure with R squared measure
# Rank the features other than `v` as potential interaction partners of `v`.
#
# obj: a shapviz object.
# v:   name of the feature whose SHAP values are inspected.
#
# Returns NULL if `obj` has at most one column. Otherwise returns a named
# numeric vector (one entry per other feature), sorted in decreasing order.
# If SHAP interaction values are available, the score is twice the mean
# absolute interaction value. If not, `v` is binned, the R squared of
# `lm(SHAP value of v ~ feature)` is computed within each bin, and the
# per-bin values are averaged using the number of non-missing feature values
# as weights.
potential_interactions_rsq <- function(obj, v) {
  stopifnot(is.shapviz(obj))
  shap_values <- get_shap_values(obj)
  shap_inter <- get_shap_interactions(obj)
  features <- get_feature_values(obj)
  feature_names <- colnames(obj)
  others <- setdiff(feature_names, v)
  stopifnot(v %in% feature_names)
  if (ncol(obj) <= 1L) {
    return(NULL)
  }

  # Simple case: we have SHAP interaction values
  if (!is.null(shap_inter)) {
    mean_abs_inter <- colMeans(abs(shap_inter[, v, ]))
    return(sort(2 * mean_abs_inter[others], decreasing = TRUE))
  }

  # Complicated case: we need to rely on R squared based heuristic

  # R squared of regressing the SHAP values `s` on one feature column;
  # NA if the model cannot be fitted.
  rsq_one_feature <- function(x_col, s) {
    tryCatch(
      summary(stats::lm(s ~ x_col))$r.squared,
      error = function(e) NA_real_
    )
  }

  # R squared for every feature column of one bin.
  rsq_one_bin <- function(s, x) {
    vapply(x, rsq_one_feature, FUN.VALUE = numeric(1), s = s)
  }

  n_obs <- nrow(features)
  n_bins <- ceiling(min(sqrt(n_obs), n_obs / 20))
  bin_id <- shapviz:::.fast_bin(features[[v]], n_bins = n_bins)

  shap_by_bin <- split(shap_values[, v], bin_id)
  others_by_bin <- split(features[others], bin_id)

  # One row per bin, one column per feature
  weights <- do.call(
    rbind,
    lapply(others_by_bin, function(z) colSums(!is.na(z)))
  )
  rsq_by_bin <- do.call(rbind, Map(rsq_one_bin, shap_by_bin, others_by_bin))

  # Note: bins with a missing R squared add nothing to the numerator, but
  # their weight still enters the denominator.
  weighted_rsq <- colSums(weights * rsq_by_bin, na.rm = TRUE) / colSums(weights)
  sort(weighted_rsq, decreasing = TRUE)
}
# Current implementation
# (heuristic based on weighted squared correlation; scores the other features
# as potential interaction partners of PSA)
potential_interactions(shp, v = "PSA")
#> GLEASON DPROS VOL DCAPS RACE AGE
#> 0.14827267 0.10383619 0.07988404 0.07166984 0.06715848 0.05922560
# Suggested implementation
# (R squared heuristic defined above; in the output below the scores of VOL,
# DCAPS and AGE are unchanged, while GLEASON, DPROS and RACE score higher)
potential_interactions_rsq(shp, v = "PSA")
#> GLEASON DPROS VOL RACE DCAPS AGE
#> 0.32998601 0.25517234 0.07988404 0.07827180 0.07166984 0.05922560
Created on 2023-10-24 with reprex v2.0.2