diff --git a/R/shapviz.R b/R/shapviz.R index 2f501fb..e2646ae 100644 --- a/R/shapviz.R +++ b/R/shapviz.R @@ -410,10 +410,10 @@ shapviz.predict_parts <- function(object, ...) { #' @describeIn shapviz #' Creates a "shapviz" object from `shapr::explain()`. #' @export -shapviz.shapr <- function(object, X = object[["x_test"]], collapse = NULL, ...) { - dt <- as.matrix(object[["dt"]]) +shapviz.shapr <- function(object, X = as.data.frame(object$internal$data$x_explain), collapse = NULL, ...) { + dt <- as.matrix(object[["shapley_values_est"]]) shapviz.matrix( - object = dt[, setdiff(colnames(dt), "none"), drop = FALSE], + object = dt[, setdiff(colnames(dt), c("none","explain_id")), drop = FALSE], X = X, baseline = dt[1L, "none"], collapse = collapse diff --git a/man/shapviz.Rd b/man/shapviz.Rd index 376cb50..510cb76 100644 --- a/man/shapviz.Rd +++ b/man/shapviz.Rd @@ -46,7 +46,12 @@ shapviz(object, ...) \method{shapviz}{predict_parts}(object, ...) -\method{shapviz}{shapr}(object, X = object[["x_test"]], collapse = NULL, ...) +\method{shapviz}{shapr}( + object, + X = as.data.frame(object$internal$data$x_explain), + collapse = NULL, + ... +) \method{shapviz}{kernelshap}(object, X = object[["X"]], which_class = NULL, collapse = NULL, ...) diff --git a/vignettes/basic_use.Rmd b/vignettes/basic_use.Rmd index ad30eab..71477de 100644 --- a/vignettes/basic_use.Rmd +++ b/vignettes/basic_use.Rmd @@ -163,10 +163,14 @@ library(shapviz) library(shapr) fit <- lm(Sepal.Length ~ ., data = iris) -x <- shapr(iris, fit) explanation <- shapr::explain( - iris, approach = "ctree", explainer = x, prediction_zero = mean(iris$Sepal.Length) + model = fit, + x_train = iris[-1], + x_explain = iris[-1], + approach = "ctree", + phi0 = mean(iris$Sepal.Length) ) + shp <- shapviz(explanation) sv_importance(shp) sv_dependence(shp, "Sepal.Width")