Skip to content

Commit 9994f7f

Browse files
authored
Merge pull request #482 from stan-dev/draws_format_arg
Custom draws_format output
2 parents 5f7f4db + ab2051c commit 9994f7f

13 files changed

+485
-86
lines changed

NEWS.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,13 @@ Stan programs requires CmdStan >= 2.26. (#434)
2828

2929
* Suppressing compilation messages when not in interactive mode. (#462, @wlandau)
3030

31-
* Add a new `error_on_NA` argument to `cmdstan_version()` to optionally return `NULL`
32-
if the CmdStan path is not found (#467, @wlandau).
31+
* New `error_on_NA` argument for `cmdstan_version()` to optionally return `NULL`
32+
(instead of erroring) if the CmdStan path is not found (#467, @wlandau).
33+
34+
* New `format` argument for `$draws()`, `$sampler_diagnostics()`,
35+
`read_cmdstan_csv()`, and `as_cmdstan_fit`(). This controls the format of the
36+
draws returned or stored in the object. Changing the format can improve speed
37+
and memory usage for large models. (#482)
3338

3439
# cmdstanr 0.3.0
3540

R/csv.R

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
#' Read CmdStan CSV files into R
22
#'
33
#' @description `read_cmdstan_csv()` is used internally by CmdStanR to read
4-
#' CmdStan's output CSV files into \R. It can
5-
#' also be used by CmdStan users as a more flexible and efficient alternative
6-
#' to `rstan::read_stan_csv()`. See the **Value** section for details on the
7-
#' structure of the returned list.
4+
#' CmdStan's output CSV files into \R. It can also be used by CmdStan users as
5+
#' a more flexible and efficient alternative to `rstan::read_stan_csv()`. See
6+
#' the **Value** section for details on the structure of the returned list.
87
#'
98
#' It is also possible to create CmdStanR's fitted model objects directly from
109
#' CmdStan CSV files using the `as_cmdstan_fit()` function.
@@ -22,6 +21,10 @@
2221
#' @param sampler_diagnostics Works the same way as `variables` but for sampler
2322
#' diagnostic variables (e.g., `"treedepth__"`, `"accept_stat__"`, etc.).
2423
#' Ignored if the model was not fit using MCMC.
24+
#' @param format The format for storing the draws or point estimates. The
25+
#' default depends on the method used to fit the model. See
26+
#' [draws][fit-method-draws] for details, in particular the note about speed
27+
#' and memory for models with many parameters.
2528
#'
2629
#' @return
2730
#'
@@ -49,14 +52,16 @@
4952
#' or their diagonals, depending on the type of metric used.
5053
#' * `step_size`: A list (one element per chain) of the step sizes used.
5154
#' * `warmup_draws`: If `save_warmup` was `TRUE` when fitting the model then a
52-
#' [`draws_array`][posterior::draws_array] of warmup draws.
53-
#' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] of
54-
#' post-warmup draws.
55+
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
56+
#' specified) of warmup draws.
57+
#' * `post_warmup_draws`: A [`draws_array`][posterior::draws_array] (or
58+
#' different format if `format` is specified) of post-warmup draws.
5559
#' * `warmup_sampler_diagnostics`: If `save_warmup` was `TRUE` when fitting the
56-
#' model then a [`draws_array`][posterior::draws_array] of warmup draws of the
57-
#' sampler diagnostic variables.
58-
#' * `post_warmup_sampler_diagnostics`: A [`draws_array`][posterior::draws_array]
59-
#' of post-warmup draws of the sampler diagnostic variables.
60+
#' model then a [`draws_array`][posterior::draws_array] (or different format if
61+
#' `format` is specified) of warmup draws of the sampler diagnostic variables.
62+
#' * `post_warmup_sampler_diagnostics`: A
63+
#' [`draws_array`][posterior::draws_array] (or different format if `format` is
64+
#' specified) of post-warmup draws of the sampler diagnostic variables.
6065
#'
6166
#' For [optimization][model-method-optimize] the returned list also includes the
6267
#' following components:
@@ -66,8 +71,9 @@
6671
#' For [variational inference][model-method-variational] the returned list also
6772
#' includes the following components:
6873
#'
69-
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] of draws from the
70-
#' approximate posterior distribution.
74+
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
75+
#' if `format` is specified) of draws from the approximate posterior
76+
#' distribution.
7177
#'
7278
#' For [standalone generated quantities][model-method-generate-quantities] the
7379
#' returned list also includes the following components:
@@ -117,7 +123,13 @@
117123
#'
118124
read_cmdstan_csv <- function(files,
119125
variables = NULL,
120-
sampler_diagnostics = NULL) {
126+
sampler_diagnostics = NULL,
127+
format = getOption("cmdstanr_draws_format", NULL)) {
128+
valid_draws_formats <- c("draws_array", "array", "draws_matrix", "matrix",
129+
"draws_list", "list", "draws_df", "df", "data.frame")
130+
if (!is.null(format) && !(format %in% valid_draws_formats)) {
131+
stop("The supplied draws format is not valid.", call. = FALSE)
132+
}
121133
checkmate::assert_file_exists(files, access = "r", extension = "csv")
122134
metadata <- NULL
123135
warmup_draws <- list()
@@ -171,7 +183,7 @@ read_cmdstan_csv <- function(files,
171183
uniq_seed <- unique(metadata$seed)
172184
if (length(uniq_seed) == 1) {
173185
metadata$seed <- uniq_seed
174-
}
186+
}
175187
if (is.null(variables)) { # variables = NULL returns all
176188
variables <- metadata$model_params
177189
} else if (!any(nzchar(variables))) { # if variables = "" returns none
@@ -224,14 +236,14 @@ read_cmdstan_csv <- function(files,
224236
)
225237
)
226238
if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) {
227-
warmup_sampler_diagnostics[[warmup_sd_id]] <-
239+
warmup_sampler_diagnostics[[warmup_sd_id]] <-
228240
post_warmup_sampler_diagnostics[[post_warmup_sd_id]][1:num_warmup_draws,,drop = FALSE]
229241
if (num_post_warmup_draws > 0) {
230-
post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <-
242+
post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <-
231243
post_warmup_sampler_diagnostics[[post_warmup_sd_id]][(num_warmup_draws+1):(num_warmup_draws + num_post_warmup_draws),,drop = FALSE]
232244
} else {
233245
post_warmup_sampler_diagnostics[[post_warmup_sd_id]] <- NULL
234-
}
246+
}
235247
}
236248
}
237249
if (length(variables) > 0) {
@@ -245,7 +257,7 @@ read_cmdstan_csv <- function(files,
245257
)
246258
)
247259
if (metadata$method == "sample" && metadata$save_warmup == 1 && num_warmup_draws > 0) {
248-
warmup_draws[[warmup_draws_list_id]] <-
260+
warmup_draws[[warmup_draws_list_id]] <-
249261
draws[[draws_list_id]][1:num_warmup_draws,,drop = FALSE]
250262
if (num_post_warmup_draws > 0) {
251263
draws[[draws_list_id]] <- draws[[draws_list_id]][(num_warmup_draws+1):(num_warmup_draws + num_post_warmup_draws),,drop = FALSE]
@@ -271,8 +283,12 @@ read_cmdstan_csv <- function(files,
271283
metadata$stan_variables <- names(model_param_dims)
272284

273285
if (metadata$method == "sample") {
286+
if (is.null(format)) {
287+
format <- "draws_array"
288+
}
289+
as_draws_format <- as_draws_format_fun(format)
274290
if (length(warmup_draws) > 0) {
275-
warmup_draws <- posterior::as_draws_array(warmup_draws)
291+
warmup_draws <- do.call(as_draws_format, list(warmup_draws))
276292
posterior::variables(warmup_draws) <- repaired_variables
277293
if (posterior::niterations(warmup_draws) == 0) {
278294
warmup_draws <- NULL
@@ -281,7 +297,7 @@ read_cmdstan_csv <- function(files,
281297
warmup_draws <- NULL
282298
}
283299
if (length(draws) > 0) {
284-
draws <- posterior::as_draws_array(draws)
300+
draws <- do.call(as_draws_format, list(draws))
285301
posterior::variables(draws) <- repaired_variables
286302
if (posterior::niterations(draws) == 0) {
287303
draws <- NULL
@@ -290,15 +306,15 @@ read_cmdstan_csv <- function(files,
290306
draws <- NULL
291307
}
292308
if (length(warmup_sampler_diagnostics) > 0) {
293-
warmup_sampler_diagnostics <- posterior::as_draws_array(warmup_sampler_diagnostics)
309+
warmup_sampler_diagnostics <- do.call(as_draws_format, list(warmup_sampler_diagnostics))
294310
if (posterior::niterations(warmup_sampler_diagnostics) == 0) {
295311
warmup_sampler_diagnostics <- NULL
296312
}
297313
} else {
298314
warmup_sampler_diagnostics <- NULL
299315
}
300316
if (length(post_warmup_sampler_diagnostics) > 0) {
301-
post_warmup_sampler_diagnostics <- posterior::as_draws_array(post_warmup_sampler_diagnostics)
317+
post_warmup_sampler_diagnostics <- do.call(as_draws_format, list(post_warmup_sampler_diagnostics))
302318
if (posterior::niterations(post_warmup_sampler_diagnostics) == 0) {
303319
post_warmup_sampler_diagnostics <- NULL
304320
}
@@ -316,24 +332,31 @@ read_cmdstan_csv <- function(files,
316332
post_warmup_sampler_diagnostics = post_warmup_sampler_diagnostics
317333
)
318334
} else if (metadata$method == "variational") {
319-
variational_draws <- posterior::as_draws_matrix(
320-
draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop=FALSE]
321-
)
335+
if (is.null(format)) {
336+
format <- "draws_matrix"
337+
}
338+
as_draws_format <- as_draws_format_fun(format)
339+
variational_draws <- do.call(as_draws_format, list(draws[[1]][-1, colnames(draws[[1]]) != "lp__", drop=FALSE]))
322340
if (!is.null(variational_draws)) {
323341
if ("log_p__" %in% posterior::variables(variational_draws)) {
324342
variational_draws <- posterior::rename_variables(variational_draws, lp__ = "log_p__")
325343
}
326344
if ("log_g__" %in% posterior::variables(variational_draws)) {
327345
variational_draws <- posterior::rename_variables(variational_draws, lp_approx__ = "log_g__")
328-
}
346+
}
329347
posterior::variables(variational_draws) <- repaired_variables
330348
}
331349
list(
332350
metadata = metadata,
333351
draws = variational_draws
334352
)
335353
} else if (metadata$method == "optimize") {
336-
point_estimates <- posterior::as_draws_matrix(draws[[1]][1,, drop=FALSE])[, variables]
354+
if (is.null(format)) {
355+
format <- "draws_matrix"
356+
}
357+
as_draws_format <- as_draws_format_fun(format)
358+
point_estimates <- do.call(as_draws_format, list(draws[[1]][1,, drop=FALSE]))
359+
point_estimates <- posterior::subset_draws(point_estimates, variable = variables)
337360
if (!is.null(point_estimates)) {
338361
posterior::variables(point_estimates) <- repaired_variables
339362
}
@@ -342,7 +365,11 @@ read_cmdstan_csv <- function(files,
342365
point_estimates = point_estimates
343366
)
344367
} else if (metadata$method == "generate_quantities") {
345-
draws <- posterior::as_draws_array(draws)
368+
if (is.null(format)) {
369+
format <- "draws_array"
370+
}
371+
as_draws_format <- as_draws_format_fun(format)
372+
draws <- do.call(as_draws_format, list(draws))
346373
if (!is.null(draws)) {
347374
posterior::variables(draws) <- repaired_variables
348375
}
@@ -374,8 +401,8 @@ read_sample_csv <- function(files,
374401
#' be performed after reading in the files? The default is `TRUE` but set to
375402
#' `FALSE` to avoid checking for problems with divergences and treedepth.
376403
#'
377-
as_cmdstan_fit <- function(files, check_diagnostics = TRUE) {
378-
csv_contents <- read_cmdstan_csv(files)
404+
as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("cmdstanr_draws_format", NULL)) {
405+
csv_contents <- read_cmdstan_csv(files, format = format)
379406
switch(
380407
csv_contents$metadata$method,
381408
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
@@ -656,7 +683,7 @@ read_csv_metadata <- function(csv_file) {
656683
check_csv_metadata_matches <- function(csv_metadata) {
657684
model_name <- sapply(csv_metadata, function(x) x$model_name)
658685
if (!all(model_name == model_name[1])) {
659-
stop("Supplied CSV files were not generated with the same model!", call. = FALSE)
686+
stop("Supplied CSV files were not generated with the same model!", call. = FALSE)
660687
}
661688
method <- sapply(csv_metadata, function(x) x$method)
662689
if (!all(method == method[1])) {

0 commit comments

Comments
 (0)