1
1
# ' Read CmdStan CSV files into R
2
2
# '
3
3
# ' @description `read_cmdstan_csv()` is used internally by CmdStanR to read
4
- # ' CmdStan's output CSV files into \R. It can
5
- # ' also be used by CmdStan users as a more flexible and efficient alternative
6
- # ' to `rstan::read_stan_csv()`. See the **Value** section for details on the
7
- # ' structure of the returned list.
4
+ # ' CmdStan's output CSV files into \R. It can also be used by CmdStan users as
5
+ # ' a more flexible and efficient alternative to `rstan::read_stan_csv()`. See
6
+ # ' the **Value** section for details on the structure of the returned list.
8
7
# '
9
8
# ' It is also possible to create CmdStanR's fitted model objects directly from
10
9
# ' CmdStan CSV files using the `as_cmdstan_fit()` function.
22
21
# ' @param sampler_diagnostics Works the same way as `variables` but for sampler
23
22
# ' diagnostic variables (e.g., `"treedepth__"`, `"accept_stat__"`, etc.).
24
23
# ' Ignored if the model was not fit using MCMC.
24
+ # ' @param format The format for storing the draws or point estimates. The
25
+ # ' default depends on the method used to fit the model. See
26
+ # ' [draws][fit-method-draws] for details, in particular the note about speed
27
+ # ' and memory for models with many parameters.
25
28
# '
26
29
# ' @return
27
30
# '
49
52
# ' or their diagonals, depending on the type of metric used.
50
53
# ' * `step_size`: A list (one element per chain) of the step sizes used.
51
54
# ' * `warmup_draws`: If `save_warmup` was `TRUE` when fitting the model then a
52
- # ' [`draws_array`][posterior::draws_array] of warmup draws.
53
- # ' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] of
54
- # ' post-warmup draws.
55
+ # ' [`draws_array`][posterior::draws_array] (or different format if `format` is
56
+ # ' specified) of warmup draws.
57
+ # ' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] (or
58
+ # ' different format if `format` is specified) of post-warmup draws.
55
59
# ' * `warmup_sampler_diagnostics`: If `save_warmup` was `TRUE` when fitting the
56
- # ' model then a [`draws_array`][posterior::draws_array] of warmup draws of the
57
- # ' sampler diagnostic variables.
58
- # ' * `post_warmup_sampler_diagnostics`: A [`draws_array`][posterior::draws_array]
59
- # ' of post-warmup draws of the sampler diagnostic variables.
60
+ # ' model then a [`draws_array`][posterior::draws_array] (or different format if
61
+ # ' `format` is specified) of warmup draws of the sampler diagnostic variables.
62
+ # ' * `post_warmup_sampler_diagnostics`: A
63
+ # ' [`draws_array`][posterior::draws_array] (or different format if `format` is
64
+ # ' specified) of post-warmup draws of the sampler diagnostic variables.
60
65
# '
61
66
# ' For [optimization][model-method-optimize] the returned list also includes the
62
67
# ' following components:
66
71
# ' For [variational inference][model-method-variational] the returned list also
67
72
# ' includes the following components:
68
73
# '
69
- # ' * `draws`: A [`draws_matrix`][posterior::draws_matrix] of draws from the
70
- # ' approximate posterior distribution.
74
+ # ' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
75
+ # ' if `format` is specified) of draws from the approximate posterior
76
+ # ' distribution.
71
77
# '
72
78
# ' For [standalone generated quantities][model-method-generate-quantities] the
73
79
# ' returned list also includes the following components:
117
123
# '
118
124
read_cmdstan_csv <- function (files ,
119
125
variables = NULL ,
120
- sampler_diagnostics = NULL ) {
126
+ sampler_diagnostics = NULL ,
127
+ format = getOption(" cmdstanr_draws_format" , NULL )) {
128
+ valid_draws_formats <- c(" draws_array" , " array" , " draws_matrix" , " matrix" ,
129
+ " draws_list" , " list" , " draws_df" , " df" , " data.frame" )
130
+ if (! is.null(format ) && ! (format %in% valid_draws_formats )) {
131
+ stop(" The supplied draws format is not valid." , call. = FALSE )
132
+ }
121
133
checkmate :: assert_file_exists(files , access = " r" , extension = " csv" )
122
134
metadata <- NULL
123
135
warmup_draws <- list ()
@@ -171,7 +183,7 @@ read_cmdstan_csv <- function(files,
171
183
uniq_seed <- unique(metadata $ seed )
172
184
if (length(uniq_seed ) == 1 ) {
173
185
metadata $ seed <- uniq_seed
174
- }
186
+ }
175
187
if (is.null(variables )) { # variables = NULL returns all
176
188
variables <- metadata $ model_params
177
189
} else if (! any(nzchar(variables ))) { # if variables = "" returns none
@@ -224,14 +236,14 @@ read_cmdstan_csv <- function(files,
224
236
)
225
237
)
226
238
if (metadata $ method == " sample" && metadata $ save_warmup == 1 && num_warmup_draws > 0 ) {
227
- warmup_sampler_diagnostics [[warmup_sd_id ]] <-
239
+ warmup_sampler_diagnostics [[warmup_sd_id ]] <-
228
240
post_warmup_sampler_diagnostics [[post_warmup_sd_id ]][1 : num_warmup_draws ,,drop = FALSE ]
229
241
if (num_post_warmup_draws > 0 ) {
230
- post_warmup_sampler_diagnostics [[post_warmup_sd_id ]] <-
242
+ post_warmup_sampler_diagnostics [[post_warmup_sd_id ]] <-
231
243
post_warmup_sampler_diagnostics [[post_warmup_sd_id ]][(num_warmup_draws + 1 ): (num_warmup_draws + num_post_warmup_draws ),,drop = FALSE ]
232
244
} else {
233
245
post_warmup_sampler_diagnostics [[post_warmup_sd_id ]] <- NULL
234
- }
246
+ }
235
247
}
236
248
}
237
249
if (length(variables ) > 0 ) {
@@ -245,7 +257,7 @@ read_cmdstan_csv <- function(files,
245
257
)
246
258
)
247
259
if (metadata $ method == " sample" && metadata $ save_warmup == 1 && num_warmup_draws > 0 ) {
248
- warmup_draws [[warmup_draws_list_id ]] <-
260
+ warmup_draws [[warmup_draws_list_id ]] <-
249
261
draws [[draws_list_id ]][1 : num_warmup_draws ,,drop = FALSE ]
250
262
if (num_post_warmup_draws > 0 ) {
251
263
draws [[draws_list_id ]] <- draws [[draws_list_id ]][(num_warmup_draws + 1 ): (num_warmup_draws + num_post_warmup_draws ),,drop = FALSE ]
@@ -271,8 +283,12 @@ read_cmdstan_csv <- function(files,
271
283
metadata $ stan_variables <- names(model_param_dims )
272
284
273
285
if (metadata $ method == " sample" ) {
286
+ if (is.null(format )) {
287
+ format <- " draws_array"
288
+ }
289
+ as_draws_format <- as_draws_format_fun(format )
274
290
if (length(warmup_draws ) > 0 ) {
275
- warmup_draws <- posterior :: as_draws_array( warmup_draws )
291
+ warmup_draws <- do.call( as_draws_format , list ( warmup_draws ) )
276
292
posterior :: variables(warmup_draws ) <- repaired_variables
277
293
if (posterior :: niterations(warmup_draws ) == 0 ) {
278
294
warmup_draws <- NULL
@@ -281,7 +297,7 @@ read_cmdstan_csv <- function(files,
281
297
warmup_draws <- NULL
282
298
}
283
299
if (length(draws ) > 0 ) {
284
- draws <- posterior :: as_draws_array( draws )
300
+ draws <- do.call( as_draws_format , list ( draws ) )
285
301
posterior :: variables(draws ) <- repaired_variables
286
302
if (posterior :: niterations(draws ) == 0 ) {
287
303
draws <- NULL
@@ -290,15 +306,15 @@ read_cmdstan_csv <- function(files,
290
306
draws <- NULL
291
307
}
292
308
if (length(warmup_sampler_diagnostics ) > 0 ) {
293
- warmup_sampler_diagnostics <- posterior :: as_draws_array( warmup_sampler_diagnostics )
309
+ warmup_sampler_diagnostics <- do.call( as_draws_format , list ( warmup_sampler_diagnostics ) )
294
310
if (posterior :: niterations(warmup_sampler_diagnostics ) == 0 ) {
295
311
warmup_sampler_diagnostics <- NULL
296
312
}
297
313
} else {
298
314
warmup_sampler_diagnostics <- NULL
299
315
}
300
316
if (length(post_warmup_sampler_diagnostics ) > 0 ) {
301
- post_warmup_sampler_diagnostics <- posterior :: as_draws_array( post_warmup_sampler_diagnostics )
317
+ post_warmup_sampler_diagnostics <- do.call( as_draws_format , list ( post_warmup_sampler_diagnostics ) )
302
318
if (posterior :: niterations(post_warmup_sampler_diagnostics ) == 0 ) {
303
319
post_warmup_sampler_diagnostics <- NULL
304
320
}
@@ -316,24 +332,31 @@ read_cmdstan_csv <- function(files,
316
332
post_warmup_sampler_diagnostics = post_warmup_sampler_diagnostics
317
333
)
318
334
} else if (metadata $ method == " variational" ) {
319
- variational_draws <- posterior :: as_draws_matrix(
320
- draws [[1 ]][- 1 , colnames(draws [[1 ]]) != " lp__" , drop = FALSE ]
321
- )
335
+ if (is.null(format )) {
336
+ format <- " draws_matrix"
337
+ }
338
+ as_draws_format <- as_draws_format_fun(format )
339
+ variational_draws <- do.call(as_draws_format , list (draws [[1 ]][- 1 , colnames(draws [[1 ]]) != " lp__" , drop = FALSE ]))
322
340
if (! is.null(variational_draws )) {
323
341
if (" log_p__" %in% posterior :: variables(variational_draws )) {
324
342
variational_draws <- posterior :: rename_variables(variational_draws , lp__ = " log_p__" )
325
343
}
326
344
if (" log_g__" %in% posterior :: variables(variational_draws )) {
327
345
variational_draws <- posterior :: rename_variables(variational_draws , lp_approx__ = " log_g__" )
328
- }
346
+ }
329
347
posterior :: variables(variational_draws ) <- repaired_variables
330
348
}
331
349
list (
332
350
metadata = metadata ,
333
351
draws = variational_draws
334
352
)
335
353
} else if (metadata $ method == " optimize" ) {
336
- point_estimates <- posterior :: as_draws_matrix(draws [[1 ]][1 ,, drop = FALSE ])[, variables ]
354
+ if (is.null(format )) {
355
+ format <- " draws_matrix"
356
+ }
357
+ as_draws_format <- as_draws_format_fun(format )
358
+ point_estimates <- do.call(as_draws_format , list (draws [[1 ]][1 ,, drop = FALSE ]))
359
+ point_estimates <- posterior :: subset_draws(point_estimates , variable = variables )
337
360
if (! is.null(point_estimates )) {
338
361
posterior :: variables(point_estimates ) <- repaired_variables
339
362
}
@@ -342,7 +365,11 @@ read_cmdstan_csv <- function(files,
342
365
point_estimates = point_estimates
343
366
)
344
367
} else if (metadata $ method == " generate_quantities" ) {
345
- draws <- posterior :: as_draws_array(draws )
368
+ if (is.null(format )) {
369
+ format <- " draws_array"
370
+ }
371
+ as_draws_format <- as_draws_format_fun(format )
372
+ draws <- do.call(as_draws_format , list (draws ))
346
373
if (! is.null(draws )) {
347
374
posterior :: variables(draws ) <- repaired_variables
348
375
}
@@ -374,8 +401,8 @@ read_sample_csv <- function(files,
374
401
# ' be performed after reading in the files? The default is `TRUE` but set to
375
402
# ' `FALSE` to avoid checking for problems with divergences and treedepth.
376
403
# '
377
- as_cmdstan_fit <- function (files , check_diagnostics = TRUE ) {
378
- csv_contents <- read_cmdstan_csv(files )
404
+ as_cmdstan_fit <- function (files , check_diagnostics = TRUE , format = getOption( " cmdstanr_draws_format " , NULL ) ) {
405
+ csv_contents <- read_cmdstan_csv(files , format = format )
379
406
switch (
380
407
csv_contents $ metadata $ method ,
381
408
" sample" = CmdStanMCMC_CSV $ new(csv_contents , files , check_diagnostics ),
@@ -656,7 +683,7 @@ read_csv_metadata <- function(csv_file) {
656
683
check_csv_metadata_matches <- function (csv_metadata ) {
657
684
model_name <- sapply(csv_metadata , function (x ) x $ model_name )
658
685
if (! all(model_name == model_name [1 ])) {
659
- stop(" Supplied CSV files were not generated with the same model!" , call. = FALSE )
686
+ stop(" Supplied CSV files were not generated with the same model!" , call. = FALSE )
660
687
}
661
688
method <- sapply(csv_metadata , function (x ) x $ method )
662
689
if (! all(method == method [1 ])) {
0 commit comments