Skip to content

Commit e2a686e

Browse files
authored
Merge a6eb030 into 83bbc35
2 parents 83bbc35 + a6eb030 commit e2a686e

File tree

5 files changed

+363
-5
lines changed

5 files changed

+363
-5
lines changed

R/fit.R

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,8 +1398,11 @@ CmdStanMCMC <- R6::R6Class(
13981398
#' but will result in a warning from the \pkg{loo} package.
13991399
#' * If `r_eff` is anything else, that object will be passed as the `r_eff`
14001400
#' argument to [loo::loo.array()].
1401+
#' @param moment_match (boolean) Whether to use a moment-matching correction for
1402+
#' for problematic observations.
14011403
#' @param ... Other arguments (e.g., `cores`, `save_psis`, etc.) passed to
1402-
#' [loo::loo.array()].
1404+
#' [loo::loo.array()] or [loo::loo_moment_match.default()]
1405+
#' (if `moment_match` = `TRUE` is set).
14031406
#'
14041407
#' @return The object returned by [loo::loo.array()].
14051408
#'
@@ -1416,7 +1419,7 @@ CmdStanMCMC <- R6::R6Class(
14161419
#' print(loo_result)
14171420
#' }
14181421
#'
1419-
loo <- function(variables = "log_lik", r_eff = TRUE, ...) {
1422+
loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...) {
14201423
require_suggested_package("loo")
14211424
LLarray <- self$draws(variables, format = "draws_array")
14221425
if (is.logical(r_eff)) {
@@ -1427,7 +1430,39 @@ loo <- function(variables = "log_lik", r_eff = TRUE, ...) {
14271430
r_eff <- NULL
14281431
}
14291432
}
1430-
loo::loo.array(LLarray, r_eff = r_eff, ...)
1433+
1434+
if (moment_match == TRUE) {
1435+
# Moment-matching requires log-prob, constrain, and unconstrain methods
1436+
if (is.null(private$model_methods_env_$model_ptr)) {
1437+
self$init_model_methods()
1438+
}
1439+
1440+
suppressWarnings(loo_result <- loo::loo.array(LLarray, r_eff = r_eff, ...))
1441+
1442+
log_lik_i <- function(x, i, parameter_name = "log_lik", ...) {
1443+
ll_array <- x$draws(variables = parameter_name, format = "draws_array")[,,i]
1444+
# draws_array types don't drop the last dimension when it's 1, so we do this manually
1445+
attr(ll_array, "dim") <- attributes(ll_array)$dim[1:2]
1446+
ll_array
1447+
}
1448+
1449+
log_lik_i_upars <- function(x, upars, i, parameter_name = "log_lik", ...) {
1450+
apply(upars, 1, \(up_i) { x$constrain_variables(up_i)[[parameter_name]][i] })
1451+
}
1452+
1453+
loo::loo_moment_match.default(
1454+
x = self,
1455+
loo = loo_result,
1456+
post_draws = \(x, ...) { x$draws(format = "draws_matrix") },
1457+
log_lik_i = log_lik_i,
1458+
unconstrain_pars = \(x, pars, ...) { do.call(rbind, lapply(x$unconstrain_draws(), \(chain) { do.call(rbind, chain) })) },
1459+
log_prob_upars = \(x, upars, ...) { apply(upars, 1, x$log_prob) },
1460+
log_lik_i_upars = log_lik_i_upars,
1461+
...
1462+
)
1463+
} else {
1464+
loo::loo.array(LLarray, r_eff = r_eff, ...)
1465+
}
14311466
}
14321467
CmdStanMCMC$set("public", name = "loo", value = loo)
14331468

man/fit-method-loo.Rd

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
{
2+
"N": 262,
3+
"K": 3,
4+
"x": [
5+
[17.549928774784245, 1, 0],
6+
[18.200274723201296, 1, 0],
7+
[1.2922847983320085, 1, 0],
8+
[1.7320508075688772, 1, 0],
9+
[1.4142135623730951, 1, 0],
10+
[0, 1, 0],
11+
[8.3666002653407556, 1, 0],
12+
[8.0349237706402672, 1, 0],
13+
[1, 0, 0],
14+
[3.7416573867739413, 0, 0],
15+
[11.757976016304847, 0, 0],
16+
[4, 0, 0],
17+
[9.8488578017961039, 0, 0],
18+
[9.8994949366116654, 0, 0],
19+
[6.6332495807107996, 0, 0],
20+
[21.213203435596427, 0, 0],
21+
[6.0555759428810738, 0, 0],
22+
[8.6602540378443873, 0, 0],
23+
[1.4142135623730951, 0, 0],
24+
[18.506755523321747, 0, 0],
25+
[4.7434164902525691, 0, 0],
26+
[2, 0, 0],
27+
[9.6953597148326587, 0, 0],
28+
[11.494781424629178, 0, 0],
29+
[8.9442719099991592, 0, 0],
30+
[9.7467943448089631, 0, 0],
31+
[5.8420886675914119, 0, 0],
32+
[4.358898943540674, 0, 0],
33+
[5, 0, 0],
34+
[7.5828754440515507, 0, 0],
35+
[1.7635192088548397, 0, 0],
36+
[3.2403703492039302, 0, 0],
37+
[6.8556546004010439, 0, 0],
38+
[1.6217274740226855, 0, 0],
39+
[14.282856857085701, 0, 0],
40+
[0, 0, 0],
41+
[0, 0, 0],
42+
[15.692354826475215, 0, 0],
43+
[8.730406634286858, 0, 0],
44+
[3.872983346207417, 0, 0],
45+
[3.8574603043971822, 0, 0],
46+
[7.1414284285428504, 0, 0],
47+
[3.872983346207417, 0, 0],
48+
[7.245688373094719, 0, 0],
49+
[3.7080992435478315, 0, 0],
50+
[1.0816653826391966, 0, 0],
51+
[1, 0, 0],
52+
[1.4142135623730951, 0, 0],
53+
[0, 0, 0],
54+
[1.4142135623730951, 0, 0],
55+
[1.4142135623730951, 0, 0],
56+
[5.4772255750516612, 0, 0],
57+
[2.6457513110645907, 0, 0],
58+
[4.2426406871192848, 0, 0],
59+
[1.4142135623730951, 0, 0],
60+
[16.30950643030009, 0, 0],
61+
[13.19090595827292, 1, 0],
62+
[15.874507866387544, 1, 0],
63+
[0.93808315196468595, 1, 0],
64+
[19.300259065618782, 1, 0],
65+
[6.4807406984078604, 1, 0],
66+
[4.358898943540674, 1, 0],
67+
[16.217274740226856, 1, 0],
68+
[5.8309518948453007, 1, 0],
69+
[1.3228756555322954, 1, 0],
70+
[14.611639196202457, 1, 0],
71+
[3.6235341863986879, 1, 0],
72+
[12.409673645990857, 1, 0],
73+
[4.5825756949558398, 1, 0],
74+
[14.832396974191326, 1, 0],
75+
[6.1903150162168643, 1, 0],
76+
[18.774983355518586, 1, 0],
77+
[2.6457513110645907, 1, 0],
78+
[2, 1, 0],
79+
[7.713624310270756, 1, 0],
80+
[2.7386127875258306, 1, 0],
81+
[10.62449998823474, 1, 0],
82+
[13.114877048604001, 1, 0],
83+
[3.6055512754639891, 1, 0],
84+
[4.2426406871192848, 1, 0],
85+
[0, 1, 0],
86+
[5.196152422706632, 1, 0],
87+
[12.165525060596439, 1, 0],
88+
[5.6568542494923806, 1, 0],
89+
[5.2915026221291814, 1, 0],
90+
[0, 1, 0],
91+
[5.2915026221291814, 1, 0],
92+
[0, 1, 1],
93+
[3.7416573867739413, 1, 1],
94+
[2.2360679774997898, 1, 1],
95+
[0, 1, 1],
96+
[10.198039027185569, 1, 1],
97+
[5.196152422706632, 1, 1],
98+
[11.489125293076057, 1, 1],
99+
[16.06237840420901, 1, 1],
100+
[1, 1, 1],
101+
[1.4142135623730951, 1, 1],
102+
[2.4494897427831779, 1, 1],
103+
[1.7320508075688772, 1, 1],
104+
[1.7320508075688772, 1, 1],
105+
[0, 1, 1],
106+
[0, 1, 1],
107+
[0, 1, 1],
108+
[1.1180339887498949, 1, 1],
109+
[0, 1, 1],
110+
[4, 1, 1],
111+
[8.2462112512353212, 1, 1],
112+
[1, 1, 1],
113+
[4.2426406871192848, 1, 1],
114+
[11.120701416727274, 1, 1],
115+
[1.4142135623730951, 1, 1],
116+
[9.0829510622924747, 1, 1],
117+
[0, 1, 1],
118+
[1.1180339887498949, 1, 1],
119+
[13.076696830622021, 1, 1],
120+
[9.5854055730574075, 1, 1],
121+
[2.2360679774997898, 1, 1],
122+
[2.6457513110645907, 1, 1],
123+
[0, 1, 1],
124+
[2, 1, 1],
125+
[7.3314391493075899, 1, 1],
126+
[11.749893616539683, 1, 1],
127+
[1, 1, 1],
128+
[1.1180339887498949, 1, 1],
129+
[5.2915026221291814, 1, 1],
130+
[3.872983346207417, 1, 1],
131+
[0.93808315196468595, 1, 1],
132+
[0, 1, 1],
133+
[5.7445626465380286, 1, 1],
134+
[11.672617529928752, 1, 1],
135+
[11.291589790636214, 1, 1],
136+
[1.4142135623730951, 1, 1],
137+
[1.4142135623730951, 1, 1],
138+
[0, 1, 1],
139+
[1.7320508075688772, 1, 1],
140+
[6.7823299831252681, 1, 1],
141+
[8.2462112512353212, 1, 1],
142+
[0, 1, 1],
143+
[7, 1, 1],
144+
[5.2086466572421672, 1, 1],
145+
[6.7453687816160208, 1, 0],
146+
[0, 1, 0],
147+
[2, 1, 0],
148+
[0, 1, 0],
149+
[1, 1, 0],
150+
[0, 1, 0],
151+
[0, 1, 0],
152+
[3.2403703492039302, 1, 0],
153+
[0, 1, 0],
154+
[1.7320508075688772, 1, 0],
155+
[0, 1, 0],
156+
[0, 1, 0],
157+
[0, 1, 0],
158+
[1.9364916731037085, 1, 0],
159+
[0, 1, 0],
160+
[0, 1, 0],
161+
[0, 0, 0],
162+
[0.93808315196468595, 0, 0],
163+
[1.4142135623730951, 1, 0],
164+
[9, 1, 0],
165+
[5.5677643628300215, 1, 0],
166+
[3.1622776601683795, 1, 0],
167+
[3.1984371183438953, 1, 0],
168+
[0, 0, 0],
169+
[3.6055512754639891, 1, 0],
170+
[1, 1, 0],
171+
[5.7662812973353983, 1, 0],
172+
[0, 0, 0],
173+
[7.2801098892805181, 0, 0],
174+
[2.2360679774997898, 1, 0],
175+
[12.529964086141668, 1, 0],
176+
[4.8301138702933288, 1, 0],
177+
[2.4494897427831779, 1, 0],
178+
[3.1622776601683795, 1, 0],
179+
[10, 1, 0],
180+
[7.416198487095663, 1, 0],
181+
[0, 1, 0],
182+
[4.0311288741492746, 1, 0],
183+
[2.7892651361962706, 1, 0],
184+
[7.2801098892805181, 1, 0],
185+
[1.4142135623730951, 1, 0],
186+
[8.5440037453175304, 1, 0],
187+
[0, 1, 0],
188+
[0, 1, 0],
189+
[0, 1, 0],
190+
[1.7832554500127009, 1, 0],
191+
[2.2360679774997898, 0, 0],
192+
[1.7320508075688772, 0, 0],
193+
[0, 0, 0],
194+
[3.1622776601683795, 0, 0],
195+
[1, 0, 0],
196+
[0, 1, 0],
197+
[1, 1, 0],
198+
[3.4641016151377544, 1, 0],
199+
[4.1231056256176606, 1, 0],
200+
[4, 1, 0],
201+
[1.5811388300841898, 1, 0],
202+
[0, 1, 0],
203+
[4.6776062254105994, 1, 0],
204+
[13.152946437965905, 1, 0],
205+
[10.535653752852738, 1, 0],
206+
[5.9160797830996161, 1, 0],
207+
[0, 1, 0],
208+
[1.7320508075688772, 1, 0],
209+
[1.4491376746189439, 1, 0],
210+
[0, 1, 0],
211+
[7, 1, 0],
212+
[1.1180339887498949, 1, 0],
213+
[1, 1, 0],
214+
[1, 1, 0],
215+
[0, 1, 0],
216+
[0, 1, 0],
217+
[7.3484692283495345, 1, 0],
218+
[2.0493901531919199, 1, 0],
219+
[7.1589105316381767, 1, 0],
220+
[5.4772255750516612, 1, 0],
221+
[14, 0, 0],
222+
[0, 0, 0],
223+
[1.4142135623730951, 0, 0],
224+
[1.2247448713915889, 0, 0],
225+
[9.8107084351742913, 0, 0],
226+
[15.540270267920054, 0, 0],
227+
[11.832159566199232, 0, 0],
228+
[4.2426406871192848, 0, 0],
229+
[1.7320508075688772, 0, 0],
230+
[0.73484692283495345, 0, 0],
231+
[9.0553851381374173, 0, 0],
232+
[4.358898943540674, 0, 0],
233+
[4.3301270189221936, 0, 0],
234+
[7.1414284285428504, 0, 0],
235+
[0, 0, 0],
236+
[1, 0, 0],
237+
[0, 0, 0],
238+
[2.3323807579381204, 0, 0],
239+
[0, 0, 0],
240+
[1.7320508075688772, 0, 1],
241+
[0, 0, 1],
242+
[0, 0, 1],
243+
[0, 0, 1],
244+
[0, 0, 1],
245+
[5.3619026473818039, 0, 1],
246+
[0, 0, 1],
247+
[1.7320508075688772, 0, 1],
248+
[1.4142135623730951, 0, 1],
249+
[11.61895003862225, 0, 1],
250+
[0, 0, 1],
251+
[8.2613558209291522, 0, 1],
252+
[1, 0, 1],
253+
[0, 0, 1],
254+
[0, 0, 1],
255+
[0, 0, 1],
256+
[1, 0, 1],
257+
[1, 0, 1],
258+
[1.5811388300841898, 0, 1],
259+
[7.1589105316381767, 0, 1],
260+
[3.6235341863986879, 0, 1],
261+
[0, 0, 1],
262+
[0, 0, 1],
263+
[0, 0, 1],
264+
[0, 0, 1],
265+
[0, 0, 1],
266+
[0, 0, 1]
267+
],
268+
"y": [153, 127, 7, 7, 0, 0, 73, 24, 2, 2, 0, 21, 0, 179, 136, 104, 2, 5, 1, 203, 32, 1, 135, 59, 29, 120, 44, 1, 2, 193, 13, 37, 2, 0, 3, 0, 0, 15, 11, 19, 0, 19, 4, 122, 48, 0, 0, 3, 0, 9, 0, 0, 0, 12, 0, 357, 11, 60, 0, 159, 50, 48, 178, 4, 6, 0, 33, 127, 4, 63, 88, 5, 0, 0, 62, 4, 150, 38, 0, 3, 1, 14, 77, 42, 21, 1, 45, 0, 0, 0, 0, 0, 183, 28, 49, 1, 0, 0, 3, 0, 0, 0, 0, 18, 0, 0, 5, 0, 19, 5, 0, 27, 0, 0, 77, 1, 3, 2, 0, 0, 22, 102, 0, 0, 0, 0, 0, 0, 0, 4, 12, 2, 0, 0, 1, 0, 40, 0, 1, 2, 27, 0, 2, 0, 0, 0, 0, 3, 1, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 53, 69, 15, 0, 2, 4, 6, 8, 0, 0, 0, 18, 38, 0, 2, 18, 34, 1, 109, 5, 15, 0, 64, 0, 1, 0, 1, 3, 5, 7, 18, 1, 0, 0, 3, 3, 0, 19, 0, 8, 26, 50, 15, 0, 19, 5, 17, 121, 1, 0, 0, 0, 0, 4, 1, 14, 1, 25, 0, 14, 0, 59, 243, 80, 69, 14, 9, 38, 37, 48, 293, 7, 10, 19, 24, 91, 1, 0, 0, 0, 0, 148, 3, 26, 12, 77, 0, 7, 0, 1, 0, 17, 0, 7, 11, 6, 50, 1, 0, 0, 0, 171, 8],
269+
"outcome_offset": [-0.22314355131420971, -0.51082562376599072, 0, 0, 0.13353139262452005, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, 0, 0, 0, 0, -0.1541506798272585, 0, 0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0.45198512374305638, 0.35667494393873334, 0, 0, 0, -0.22314355131420971, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, 0, -0.22314355131420971, 0, -0.089612158689687402, 0, 0.25131442828090944, 0, 0, -0.25951119548508511, 0, 0.35667494393873334, 0, -0.25951119548508511, 0, 0, 0, 0, 0, -0.51082562376599072, 0, 0, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.37729423114146754, 0, 0, 0, 0, 0, 0, -0.22314355131420971, -0.1541506798272585, 0, 0, 0, 0, 0, 0, -0.916290731874155, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0, -1.6094379124341003, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, 0.13353139262452005, 0, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, 1.4552872326068431, -0.22314355131420971, -0.22314355131420971, -0.22314355131420971, 0, 0, 0.25131442828090944, 0, 0, 0, 0, 0, -0.7827593392496327, 0, 0.88730319500090338, 0.35667494393873334, 0, 0, 0.13353139262452005, 0, 0.13353139262452005, 0, 0, 0, -0.22314355131420971, 0.13353139262452005, 0, -0.22314355131420971, 0.45198512374305638, 0, 0.13353139262452005, 0.8266785731844698, 0, -0.55961578793542355, -0.22314355131420971, -0.1541506798272585, -0.22314355131420971, -0.1541506798272585, 0.8266785731844698, -0.22314355131420971, 0, -0.1541506798272585, 0, 0.028170876966697733, -0.51082562376599072, -0.1541506798272585, 0, -0.22314355131420971, 0, 0.39589565709201657, 0, 0, 0, -0.22314355131420971, -0.22314355131420971, 0, 0, -0.1541506798272585, -0.22314355131420971, 0, -0.22314355131420971, 0, 0, -0.22314355131420971, 0, 0, 0, 0.35667494393873334, 0.53899650073268446, -0.25951119548508511, -0.22314355131420971, 0, 0.61903920840622506, 0, 0, 0, 0, 0, 0, -0.22314355131420971, 0, 0, 0.8266785731844698, -0.22314355131420971, 0.35667494393873334, 0.028170876966697733, 0.53899650073268446, 0.13353139262452005, 0.13353139262452005, 0, 0, 0.13353139262452005, -0.22314355131420971, 0.35667494393873334, -0.089612158689687402, 0.13353139262452005, 0.25131442828090944, -0.51082562376599072, 0, -0.22314355131420971, 0.13353139262452005, 0, 0.35667494393873334, 0, 0, 0, 0, 0, 0, 0, 0, 0.13353139262452005, -0.1541506798272585, 0, 0, 0, 0, 0, 0, 0, -0.51082562376599072, 0.028170876966697733, 0, 0, -0.37729423114146754, -0.22314355131420971, 0, -0.22314355131420971, 0.39589565709201657, 0, 0, 0, 0],
270+
"beta_prior_scale": 2.5,
271+
"alpha_prior_scale": 5
272+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
data {
2+
int<lower=1> K;
3+
int<lower=1> N;
4+
matrix[N,K] x;
5+
array[N] int y;
6+
vector[N] outcome_offset;
7+
8+
real beta_prior_scale;
9+
real alpha_prior_scale;
10+
}
11+
parameters {
12+
vector[K] beta;
13+
real intercept;
14+
}
15+
model {
16+
y ~ poisson(exp(x * beta + intercept + outcome_offset));
17+
beta ~ normal(0,beta_prior_scale);
18+
intercept ~ normal(0,alpha_prior_scale);
19+
}
20+
generated quantities {
21+
vector[N] log_lik;
22+
for (n in 1:N)
23+
log_lik[n] = poisson_lpmf(y[n] | exp(x[n] * beta + intercept + outcome_offset[n]));
24+
}

0 commit comments

Comments
 (0)