diff --git a/R/fit.R b/R/fit.R index 17a07c37e..7bf6ec1b2 100644 --- a/R/fit.R +++ b/R/fit.R @@ -1398,8 +1398,11 @@ CmdStanMCMC <- R6::R6Class( #' but will result in a warning from the \pkg{loo} package. #' * If `r_eff` is anything else, that object will be passed as the `r_eff` #' argument to [loo::loo.array()]. +#' @param moment_match (boolean) Whether to use a moment-matching correction for +#' problematic observations. #' @param ... Other arguments (e.g., `cores`, `save_psis`, etc.) passed to -#' [loo::loo.array()]. +#' [loo::loo.array()] or [loo::loo_moment_match.default()] +#' (if `moment_match = TRUE`). #' #' @return The object returned by [loo::loo.array()]. #' @@ -1416,7 +1419,7 @@ CmdStanMCMC <- R6::R6Class( #' print(loo_result) #' } #' -loo <- function(variables = "log_lik", r_eff = TRUE, ...) { +loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) { require_suggested_package("loo") LLarray <- self$draws(variables, format = "draws_array") if (is.logical(r_eff)) { @@ -1427,7 +1430,39 @@ loo <- function(variables = "log_lik", r_eff = TRUE, ...) { r_eff <- NULL } } - loo::loo.array(LLarray, r_eff = r_eff, ...) + + if (moment_match == TRUE) { + # Moment-matching requires log-prob, constrain, and unconstrain methods + if (is.null(private$model_methods_env_$model_ptr)) { + self$init_model_methods() + } + + suppressWarnings(loo_result <- loo::loo.array(LLarray, r_eff = r_eff, ...)) + + log_lik_i <- function(x, i, parameter_name = "log_lik", ...) { + ll_array <- x$draws(variables = parameter_name, format = "draws_array")[,,i] + # draws_array types don't drop the last dimension when it's 1, so we do this manually + attr(ll_array, "dim") <- attributes(ll_array)$dim[1:2] + ll_array + } + + log_lik_i_upars <- function(x, upars, i, parameter_name = "log_lik", ...) 
{ + apply(upars, 1, \(up_i) { x$constrain_variables(up_i)[[parameter_name]][i] }) + } + + loo::loo_moment_match.default( + x = self, + loo = loo_result, + post_draws = \(x, ...) { x$draws(format = "draws_matrix") }, + log_lik_i = log_lik_i, + unconstrain_pars = \(x, pars, ...) { do.call(rbind, lapply(x$unconstrain_draws(), \(chain) { do.call(rbind, chain) })) }, + log_prob_upars = \(x, upars, ...) { apply(upars, 1, x$log_prob) }, + log_lik_i_upars = log_lik_i_upars, + ... + ) + } else { + loo::loo.array(LLarray, r_eff = r_eff, ...) + } } CmdStanMCMC$set("public", name = "loo", value = loo) diff --git a/man/fit-method-loo.Rd b/man/fit-method-loo.Rd index d72c0ecb7..5bbb1492f 100644 --- a/man/fit-method-loo.Rd +++ b/man/fit-method-loo.Rd @@ -5,7 +5,7 @@ \alias{loo} \title{Leave-one-out cross-validation (LOO-CV)} \usage{ -loo(variables = "log_lik", r_eff = TRUE, ...) +loo(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) } \arguments{ \item{variables}{(character vector) The name(s) of the variable(s) in the @@ -23,8 +23,12 @@ but will result in a warning from the \pkg{loo} package. argument to \code{\link[loo:loo]{loo::loo.array()}}. }} +\item{moment_match}{(boolean) Whether to use a moment-matching correction for +problematic observations.} + \item{...}{Other arguments (e.g., \code{cores}, \code{save_psis}, etc.) passed to -\code{\link[loo:loo]{loo::loo.array()}}.} +\code{\link[loo:loo]{loo::loo.array()}} or \code{\link[loo:loo_moment_match]{loo::loo_moment_match.default()}} +(if \code{moment_match = TRUE}).} } \value{ The object returned by \code{\link[loo:loo]{loo::loo.array()}}. 
diff --git a/tests/testthat/resources/data/loo_moment_match.data.json b/tests/testthat/resources/data/loo_moment_match.data.json new file mode 100644 index 000000000..0a2cdb801 --- /dev/null +++ b/tests/testthat/resources/data/loo_moment_match.data.json @@ -0,0 +1,272 @@ +{ + "N": 262, + "K": 3, + "x": [ + [17.549928774784245, 1, 0], + [18.200274723201296, 1, 0], + [1.2922847983320085, 1, 0], + [1.7320508075688772, 1, 0], + [1.4142135623730951, 1, 0], + [0, 1, 0], + [8.3666002653407556, 1, 0], + [8.0349237706402672, 1, 0], + [1, 0, 0], + [3.7416573867739413, 0, 0], + [11.757976016304847, 0, 0], + [4, 0, 0], + [9.8488578017961039, 0, 0], + [9.8994949366116654, 0, 0], + [6.6332495807107996, 0, 0], + [21.213203435596427, 0, 0], + [6.0555759428810738, 0, 0], + [8.6602540378443873, 0, 0], + [1.4142135623730951, 0, 0], + [18.506755523321747, 0, 0], + [4.7434164902525691, 0, 0], + [2, 0, 0], + [9.6953597148326587, 0, 0], + [11.494781424629178, 0, 0], + [8.9442719099991592, 0, 0], + [9.7467943448089631, 0, 0], + [5.8420886675914119, 0, 0], + [4.358898943540674, 0, 0], + [5, 0, 0], + [7.5828754440515507, 0, 0], + [1.7635192088548397, 0, 0], + [3.2403703492039302, 0, 0], + [6.8556546004010439, 0, 0], + [1.6217274740226855, 0, 0], + [14.282856857085701, 0, 0], + [0, 0, 0], + [0, 0, 0], + [15.692354826475215, 0, 0], + [8.730406634286858, 0, 0], + [3.872983346207417, 0, 0], + [3.8574603043971822, 0, 0], + [7.1414284285428504, 0, 0], + [3.872983346207417, 0, 0], + [7.245688373094719, 0, 0], + [3.7080992435478315, 0, 0], + [1.0816653826391966, 0, 0], + [1, 0, 0], + [1.4142135623730951, 0, 0], + [0, 0, 0], + [1.4142135623730951, 0, 0], + [1.4142135623730951, 0, 0], + [5.4772255750516612, 0, 0], + [2.6457513110645907, 0, 0], + [4.2426406871192848, 0, 0], + [1.4142135623730951, 0, 0], + [16.30950643030009, 0, 0], + [13.19090595827292, 1, 0], + [15.874507866387544, 1, 0], + [0.93808315196468595, 1, 0], + [19.300259065618782, 1, 0], + [6.4807406984078604, 1, 0], + [4.358898943540674, 
1, 0], + [16.217274740226856, 1, 0], + [5.8309518948453007, 1, 0], + [1.3228756555322954, 1, 0], + [14.611639196202457, 1, 0], + [3.6235341863986879, 1, 0], + [12.409673645990857, 1, 0], + [4.5825756949558398, 1, 0], + [14.832396974191326, 1, 0], + [6.1903150162168643, 1, 0], + [18.774983355518586, 1, 0], + [2.6457513110645907, 1, 0], + [2, 1, 0], + [7.713624310270756, 1, 0], + [2.7386127875258306, 1, 0], + [10.62449998823474, 1, 0], + [13.114877048604001, 1, 0], + [3.6055512754639891, 1, 0], + [4.2426406871192848, 1, 0], + [0, 1, 0], + [5.196152422706632, 1, 0], + [12.165525060596439, 1, 0], + [5.6568542494923806, 1, 0], + [5.2915026221291814, 1, 0], + [0, 1, 0], + [5.2915026221291814, 1, 0], + [0, 1, 1], + [3.7416573867739413, 1, 1], + [2.2360679774997898, 1, 1], + [0, 1, 1], + [10.198039027185569, 1, 1], + [5.196152422706632, 1, 1], + [11.489125293076057, 1, 1], + [16.06237840420901, 1, 1], + [1, 1, 1], + [1.4142135623730951, 1, 1], + [2.4494897427831779, 1, 1], + [1.7320508075688772, 1, 1], + [1.7320508075688772, 1, 1], + [0, 1, 1], + [0, 1, 1], + [0, 1, 1], + [1.1180339887498949, 1, 1], + [0, 1, 1], + [4, 1, 1], + [8.2462112512353212, 1, 1], + [1, 1, 1], + [4.2426406871192848, 1, 1], + [11.120701416727274, 1, 1], + [1.4142135623730951, 1, 1], + [9.0829510622924747, 1, 1], + [0, 1, 1], + [1.1180339887498949, 1, 1], + [13.076696830622021, 1, 1], + [9.5854055730574075, 1, 1], + [2.2360679774997898, 1, 1], + [2.6457513110645907, 1, 1], + [0, 1, 1], + [2, 1, 1], + [7.3314391493075899, 1, 1], + [11.749893616539683, 1, 1], + [1, 1, 1], + [1.1180339887498949, 1, 1], + [5.2915026221291814, 1, 1], + [3.872983346207417, 1, 1], + [0.93808315196468595, 1, 1], + [0, 1, 1], + [5.7445626465380286, 1, 1], + [11.672617529928752, 1, 1], + [11.291589790636214, 1, 1], + [1.4142135623730951, 1, 1], + [1.4142135623730951, 1, 1], + [0, 1, 1], + [1.7320508075688772, 1, 1], + [6.7823299831252681, 1, 1], + [8.2462112512353212, 1, 1], + [0, 1, 1], + [7, 1, 1], + [5.2086466572421672, 1, 
1], + [6.7453687816160208, 1, 0], + [0, 1, 0], + [2, 1, 0], + [0, 1, 0], + [1, 1, 0], + [0, 1, 0], + [0, 1, 0], + [3.2403703492039302, 1, 0], + [0, 1, 0], + [1.7320508075688772, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1.9364916731037085, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 0, 0], + [0.93808315196468595, 0, 0], + [1.4142135623730951, 1, 0], + [9, 1, 0], + [5.5677643628300215, 1, 0], + [3.1622776601683795, 1, 0], + [3.1984371183438953, 1, 0], + [0, 0, 0], + [3.6055512754639891, 1, 0], + [1, 1, 0], + [5.7662812973353983, 1, 0], + [0, 0, 0], + [7.2801098892805181, 0, 0], + [2.2360679774997898, 1, 0], + [12.529964086141668, 1, 0], + [4.8301138702933288, 1, 0], + [2.4494897427831779, 1, 0], + [3.1622776601683795, 1, 0], + [10, 1, 0], + [7.416198487095663, 1, 0], + [0, 1, 0], + [4.0311288741492746, 1, 0], + [2.7892651361962706, 1, 0], + [7.2801098892805181, 1, 0], + [1.4142135623730951, 1, 0], + [8.5440037453175304, 1, 0], + [0, 1, 0], + [0, 1, 0], + [0, 1, 0], + [1.7832554500127009, 1, 0], + [2.2360679774997898, 0, 0], + [1.7320508075688772, 0, 0], + [0, 0, 0], + [3.1622776601683795, 0, 0], + [1, 0, 0], + [0, 1, 0], + [1, 1, 0], + [3.4641016151377544, 1, 0], + [4.1231056256176606, 1, 0], + [4, 1, 0], + [1.5811388300841898, 1, 0], + [0, 1, 0], + [4.6776062254105994, 1, 0], + [13.152946437965905, 1, 0], + [10.535653752852738, 1, 0], + [5.9160797830996161, 1, 0], + [0, 1, 0], + [1.7320508075688772, 1, 0], + [1.4491376746189439, 1, 0], + [0, 1, 0], + [7, 1, 0], + [1.1180339887498949, 1, 0], + [1, 1, 0], + [1, 1, 0], + [0, 1, 0], + [0, 1, 0], + [7.3484692283495345, 1, 0], + [2.0493901531919199, 1, 0], + [7.1589105316381767, 1, 0], + [5.4772255750516612, 1, 0], + [14, 0, 0], + [0, 0, 0], + [1.4142135623730951, 0, 0], + [1.2247448713915889, 0, 0], + [9.8107084351742913, 0, 0], + [15.540270267920054, 0, 0], + [11.832159566199232, 0, 0], + [4.2426406871192848, 0, 0], + [1.7320508075688772, 0, 0], + [0.73484692283495345, 0, 0], + [9.0553851381374173, 0, 0], + 
[4.358898943540674, 0, 0], + [4.3301270189221936, 0, 0], + [7.1414284285428504, 0, 0], + [0, 0, 0], + [1, 0, 0], + [0, 0, 0], + [2.3323807579381204, 0, 0], + [0, 0, 0], + [1.7320508075688772, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [5.3619026473818039, 0, 1], + [0, 0, 1], + [1.7320508075688772, 0, 1], + [1.4142135623730951, 0, 1], + [11.61895003862225, 0, 1], + [0, 0, 1], + [8.2613558209291522, 0, 1], + [1, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [1, 0, 1], + [1, 0, 1], + [1.5811388300841898, 0, 1], + [7.1589105316381767, 0, 1], + [3.6235341863986879, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1], + [0, 0, 1] + ], + "y": [153, 127, 7, 7, 0, 0, 73, 24, 2, 2, 0, 21, 0, 179, 136, 104, 2, 5, 1, 203, 32, 1, 135, 59, 29, 120, 44, 1, 2, 193, 13, 37, 2, 0, 3, 0, 0, 15, 11, 19, 0, 19, 4, 122, 48, 0, 0, 3, 0, 9, 0, 0, 0, 12, 0, 357, 11, 60, 0, 159, 50, 48, 178, 4, 6, 0, 33, 127, 4, 63, 88, 5, 0, 0, 62, 4, 150, 38, 0, 3, 1, 14, 77, 42, 21, 1, 45, 0, 0, 0, 0, 0, 183, 28, 49, 1, 0, 0, 3, 0, 0, 0, 0, 18, 0, 0, 5, 0, 19, 5, 0, 27, 0, 0, 77, 1, 3, 2, 0, 0, 22, 102, 0, 0, 0, 0, 0, 0, 0, 4, 12, 2, 0, 0, 1, 0, 40, 0, 1, 2, 27, 0, 2, 0, 0, 0, 0, 3, 1, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 53, 69, 15, 0, 2, 4, 6, 8, 0, 0, 0, 18, 38, 0, 2, 18, 34, 1, 109, 5, 15, 0, 64, 0, 1, 0, 1, 3, 5, 7, 18, 1, 0, 0, 3, 3, 0, 19, 0, 8, 26, 50, 15, 0, 19, 5, 17, 121, 1, 0, 0, 0, 0, 4, 1, 14, 1, 25, 0, 14, 0, 59, 243, 80, 69, 14, 9, 38, 37, 48, 293, 7, 10, 19, 24, 91, 1, 0, 0, 0, 0, 148, 3, 26, 12, 77, 0, 7, 0, 1, 0, 17, 0, 7, 11, 6, 50, 1, 0, 0, 0, 171, 8], + "outcome_offset": [-0.22314355131420971, -0.51082562376599072, 0, 0, 0.13353139262452005, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, 0, 0, 0, 0, -0.1541506798272585, 0, 0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0.45198512374305638, 0.35667494393873334, 0, 0, 
0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, 0, -0.22314355131420971, 0, -0.089612158689687402, 0, 0.25131442828090944, 0, 0, -0.25951119548508511, 0, 0.35667494393873334, 0, -0.25951119548508511, 0, 0, 0, 0, 0, -0.51082562376599072, 0, 0, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.37729423114146754, 0, 0, 0, 0, 0, 0, -0.22314355131420971, -0.1541506798272585, 0, 0, 0, 0, 0, 0, -0.916290731874155, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0, -1.6094379124341003, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, 1.4552872326068431, -0.22314355131420971, -0.22314355131420971, -0.22314355131420971, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, -0.7827593392496327, 0, 0.88730319500090338, 0.35667494393873334, 0, 0, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0.45198512374305638, 0, 0.13353139262452005, 0.8266785731844698, 0, -0.55961578793542355, -0.22314355131420971, -0.1541506798272585, -0.22314355131420971, -0.1541506798272585, 0.8266785731844698, -0.22314355131420971, 0, -0.1541506798272585, 0, 0.028170876966697733, -0.51082562376599072, -0.1541506798272585, 0, -0.22314355131420971, 0, 0.39589565709201657, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.1541506798272585, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, -0.22314355131420971, 0, 0, 0, 0.35667494393873334, 0.53899650073268446, -0.25951119548508511, -0.22314355131420971, 0, 0.61903920840622506, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0, 0, 0.8266785731844698, -0.22314355131420971, 0.35667494393873334, 0.028170876966697733, 0.53899650073268446, 0.13353139262452005, 0.13353139262452005, 0, 0, 0.13353139262452005, -0.22314355131420971, 0.35667494393873334, -0.089612158689687402, 
0.13353139262452005, 0.25131442828090944, -0.51082562376599072, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.35667494393873334, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, -0.1541506798272585, 0, 0, 0, 0, 0, 0, 0, -0.51082562376599072, 0.028170876966697733, 0, 0, -0.37729423114146754, -0.22314355131420971, 0, -0.22314355131420971, 0.39589565709201657, 0, 0, 0, 0], + "beta_prior_scale": 2.5, + "alpha_prior_scale": 5 +} diff --git a/tests/testthat/resources/stan/loo_moment_match.stan b/tests/testthat/resources/stan/loo_moment_match.stan new file mode 100644 index 000000000..cd28fd942 --- /dev/null +++ b/tests/testthat/resources/stan/loo_moment_match.stan @@ -0,0 +1,24 @@ +data { + int K; + int N; + matrix[N,K] x; + array[N] int y; + vector[N] outcome_offset; + + real beta_prior_scale; + real alpha_prior_scale; +} +parameters { + vector[K] beta; + real intercept; +} +model { + y ~ poisson(exp(x * beta + intercept + outcome_offset)); + beta ~ normal(0,beta_prior_scale); + intercept ~ normal(0,alpha_prior_scale); +} +generated quantities { + vector[N] log_lik; + for (n in 1:N) + log_lik[n] = poisson_lpmf(y[n] | exp(x[n] * beta + intercept + outcome_offset[n])); +} diff --git a/tests/testthat/test-fit-mcmc.R b/tests/testthat/test-fit-mcmc.R index 1b949ef74..fa33b70b6 100644 --- a/tests/testthat/test-fit-mcmc.R +++ b/tests/testthat/test-fit-mcmc.R @@ -272,6 +272,29 @@ test_that("loo method works if log_lik is available", { expect_s3_class(suppressWarnings(fit_bernoulli$loo(r_eff = FALSE)), "loo") }) +test_that("loo method works with moment-matching", { + skip_if_not_installed("loo") + skip_if(os_is_wsl()) + + # Moment-matching needs model-methods, so make sure hpp is available + mod <- cmdstan_model(testing_stan_file("loo_moment_match"), force_recompile = TRUE) + data_list <- testing_data("loo_moment_match") + fit <- mod$sample(data = data_list, chains = 1) + + # Regular LOO should warn that some pareto-k are "too high" + expect_warning(fit$loo(), + "Some Pareto k 
diagnostic values are too high. See help('pareto-k-diagnostic') for details.", + fixed=TRUE) + + # After moment-matching the warning should be downgraded to "slightly high" + expect_warning(fit$loo(moment_match = TRUE), + "Some Pareto k diagnostic values are slightly high. See help('pareto-k-diagnostic') for details.", + fixed=TRUE) + + # After moment-matching with lower target threshold there should be no warning + expect_no_warning(fit$loo(moment_match = TRUE, k_threshold=0.4)) +}) + test_that("loo errors if it can't find log lik variables", { skip_if_not_installed("loo") fit_schools <- testing_fit("schools")