Skip to content

Commit a27d933

Browse files
authored
Merge pull request #419 from stan-dev/faster_CSV_read
Faster CSV read with multiple chains
2 parents 3a652c9 + ef3d394 commit a27d933

File tree

5 files changed

+132
-93
lines changed

5 files changed

+132
-93
lines changed

R/csv.R

Lines changed: 33 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ read_cmdstan_csv <- function(files,
120120
sampler_diagnostics = NULL) {
121121
checkmate::assert_file_exists(files, access = "r", extension = "csv")
122122
metadata <- NULL
123-
warmup_draws <- NULL
124-
warmup_sampler_diagnostics_draws <- NULL
125-
post_warmup_draws <- NULL
126-
post_warmup_sampler_diagnostics_draws <- NULL
127-
generated_quantities <- NULL
123+
warmup_draws <- list()
124+
post_warmup_draws <- list()
125+
warmup_sampler_diagnostics_draws <- list()
126+
post_warmup_sampler_diagnostics_draws <- list()
127+
generated_quantities <- list()
128128
variational_draws <- NULL
129129
point_estimates <- NULL
130130
inv_metric <- list()
@@ -241,49 +241,25 @@ read_cmdstan_csv <- function(files,
241241
if (metadata$method == "sample") {
242242
if (metadata$save_warmup == 1) {
243243
if (length(variables) > 0) {
244-
warmup_draws <- posterior::bind_draws(
245-
warmup_draws,
246-
posterior::as_draws_array(draws[1:num_warmup_draws, variables, drop = FALSE]),
247-
along="chain"
248-
)
244+
warmup_draws[[length(warmup_draws) + 1]] <- draws[1:num_warmup_draws, variables, drop = FALSE]
249245
if (num_post_warmup_draws > 0) {
250-
post_warmup_draws <- posterior::bind_draws(
251-
post_warmup_draws,
252-
posterior::as_draws_array(draws[(num_warmup_draws+1):all_draws, variables, drop = FALSE]),
253-
along="chain"
254-
)
246+
post_warmup_draws[[length(post_warmup_draws) + 1]] <- draws[(num_warmup_draws+1):all_draws, variables, drop = FALSE]
255247
}
256248
}
257249
if (length(sampler_diagnostics) > 0) {
258-
warmup_sampler_diagnostics_draws <- posterior::bind_draws(
259-
warmup_sampler_diagnostics_draws,
260-
posterior::as_draws_array(draws[1:num_warmup_draws, sampler_diagnostics, drop = FALSE]),
261-
along="chain"
262-
)
250+
warmup_sampler_diagnostics_draws[[length(warmup_sampler_diagnostics_draws) + 1]] <- draws[1:num_warmup_draws, sampler_diagnostics, drop = FALSE]
263251
if (num_post_warmup_draws > 0) {
264-
post_warmup_sampler_diagnostics_draws <- posterior::bind_draws(
265-
post_warmup_sampler_diagnostics_draws,
266-
posterior::as_draws_array(draws[(num_warmup_draws+1):all_draws, sampler_diagnostics, drop = FALSE]),
267-
along="chain"
268-
)
252+
post_warmup_sampler_diagnostics_draws[[length(post_warmup_sampler_diagnostics_draws) + 1]] <- draws[(num_warmup_draws+1):all_draws, sampler_diagnostics, drop = FALSE]
269253
}
270254
}
271255
} else {
272256
warmup_draws <- NULL
273257
warmup_sampler_diagnostics_draws <- NULL
274258
if (length(variables) > 0) {
275-
post_warmup_draws <- posterior::bind_draws(
276-
post_warmup_draws,
277-
posterior::as_draws_array(draws[, variables, drop = FALSE]),
278-
along="chain"
279-
)
259+
post_warmup_draws[[length(post_warmup_draws) + 1]] <- draws[, variables, drop = FALSE]
280260
}
281261
if (length(sampler_diagnostics) > 0 && all(metadata$algorithm != "fixed_param")) {
282-
post_warmup_sampler_diagnostics_draws <- posterior::bind_draws(
283-
post_warmup_sampler_diagnostics_draws,
284-
posterior::as_draws_array(draws[, sampler_diagnostics, drop = FALSE]),
285-
along="chain"
286-
)
262+
post_warmup_sampler_diagnostics_draws[[length(post_warmup_sampler_diagnostics_draws) + 1]] <- draws[, sampler_diagnostics, drop = FALSE]
287263
}
288264
}
289265
} else if (metadata$method == "variational") {
@@ -300,9 +276,7 @@ read_cmdstan_csv <- function(files,
300276
} else if (metadata$method == "optimize") {
301277
point_estimates <- posterior::as_draws_matrix(draws[1,, drop=FALSE])[, variables]
302278
} else if (metadata$method == "generate_quantities") {
303-
generated_quantities <- posterior::bind_draws(generated_quantities,
304-
posterior::as_draws_array(draws),
305-
along="chain")
279+
generated_quantities[[length(generated_quantities) + 1]] <- draws
306280
}
307281
}
308282
}
@@ -313,7 +287,6 @@ read_cmdstan_csv <- function(files,
313287
}
314288

315289
metadata$inv_metric <- NULL
316-
metadata$lines_to_skip <- NULL
317290
metadata$model_params <- repair_variable_names(metadata$model_params)
318291
repaired_variables <- repair_variable_names(variables)
319292
if (metadata$method == "variational") {
@@ -330,12 +303,16 @@ read_cmdstan_csv <- function(files,
330303
metadata$stan_variables <- names(model_param_dims)
331304

332305
if (metadata$method == "sample") {
306+
warmup_draws <- bind_list_of_draws_array(warmup_draws)
333307
if (!is.null(warmup_draws)) {
334308
posterior::variables(warmup_draws) <- repaired_variables
335309
}
310+
post_warmup_draws <- bind_list_of_draws_array(post_warmup_draws)
336311
if (!is.null(post_warmup_draws)) {
337312
posterior::variables(post_warmup_draws) <- repaired_variables
338313
}
314+
warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(warmup_sampler_diagnostics_draws)
315+
post_warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(post_warmup_sampler_diagnostics_draws)
339316
list(
340317
metadata = metadata,
341318
time = list(total = NA_integer_, chains = time),
@@ -363,6 +340,7 @@ read_cmdstan_csv <- function(files,
363340
point_estimates = point_estimates
364341
)
365342
} else if (metadata$method == "generate_quantities") {
343+
generated_quantities <- bind_list_of_draws_array(generated_quantities)
366344
if (!is.null(generated_quantities)) {
367345
posterior::variables(generated_quantities) <- repaired_variables
368346
}
@@ -422,8 +400,8 @@ CmdStanMCMC_CSV <- R6::R6Class(
422400
public = list(
423401
initialize = function(csv_contents, files, check_diagnostics = TRUE) {
424402
if (check_diagnostics) {
425-
check_divergences(csv_contents)
426-
check_sampler_transitions_treedepth(csv_contents)
403+
check_divergences(csv_contents$post_warmup_sampler_diagnostics)
404+
check_sampler_transitions_treedepth(csv_contents$post_warmup_sampler_diagnostics, csv_contents$metadata)
427405
}
428406
private$output_files_ <- files
429407
private$metadata_ <- csv_contents$metadata
@@ -708,7 +686,20 @@ check_csv_metadata_matches <- function(a, b) {
708686
list(not_matching = not_matching)
709687
}
710688

711-
689+
bind_list_of_draws_array <- function(draws, along = "chain") {
690+
if (!is.null(draws) && length(draws) > 0) {
691+
if (length(draws) > 1) {
692+
draws <- lapply(draws, posterior::as_draws_array)
693+
draws[["along"]] <- along
694+
draws <- do.call(posterior::bind_draws, draws)
695+
} else {
696+
draws <- posterior::as_draws_array(draws[[1]])
697+
}
698+
} else {
699+
draws <- NULL
700+
}
701+
draws
702+
}
712703

713704
# convert names like beta.1.1 to beta[1,1]
714705
repair_variable_names <- function(names) {

R/fit.R

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -708,17 +708,12 @@ CmdStanMCMC <- R6::R6Class(
708708
} else {
709709
if (self$runset$args$validate_csv) {
710710
fixed_param <- runset$args$method_args$fixed_param
711-
csv_contents <- read_cmdstan_csv(
712-
self$output_files(),
713-
variables = "",
714-
sampler_diagnostics =
715-
if (!fixed_param) c("treedepth__", "divergent__") else ""
716-
)
711+
private$read_csv_(variables = "",
712+
sampler_diagnostics = if (!fixed_param) c("treedepth__", "divergent__") else "")
717713
if (!fixed_param) {
718-
check_divergences(csv_contents)
719-
check_sampler_transitions_treedepth(csv_contents)
714+
check_divergences(private$sampler_diagnostics_)
715+
check_sampler_transitions_treedepth(private$sampler_diagnostics_, private$metadata_)
720716
}
721-
private$metadata_ <- csv_contents$metadata
722717
}
723718
}
724719
},
@@ -782,38 +777,54 @@ CmdStanMCMC <- R6::R6Class(
782777
private$metadata_ <- csv_contents$metadata
783778

784779
if (!is.null(csv_contents$post_warmup_draws)) {
785-
missing_variables <- !(posterior::variables(csv_contents$post_warmup_draws) %in% posterior::variables(private$draws_))
786-
private$draws_ <- posterior::bind_draws(
787-
private$draws_,
788-
csv_contents$post_warmup_draws[,,missing_variables],
789-
along="variable"
790-
)
780+
if (is.null(private$draws_)) {
781+
private$draws_ <- csv_contents$post_warmup_draws
782+
} else {
783+
missing_variables <- !(posterior::variables(csv_contents$post_warmup_draws) %in% posterior::variables(private$draws_))
784+
private$draws_ <- posterior::bind_draws(
785+
private$draws_,
786+
csv_contents$post_warmup_draws[,,missing_variables],
787+
along="variable"
788+
)
789+
}
791790
}
792791
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
793-
missing_variables <- !(posterior::variables(csv_contents$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
794-
private$sampler_diagnostics_ <- posterior::bind_draws(
795-
private$sampler_diagnostics_,
796-
csv_contents$post_warmup_sampler_diagnostics[,,missing_variables],
797-
along="variable"
798-
)
792+
if (is.null(private$sampler_diagnostics_)) {
793+
private$sampler_diagnostics_ <- csv_contents$post_warmup_sampler_diagnostics
794+
} else {
795+
missing_variables <- !(posterior::variables(csv_contents$post_warmup_sampler_diagnostics) %in% posterior::variables(private$sampler_diagnostics_))
796+
private$sampler_diagnostics_ <- posterior::bind_draws(
797+
private$sampler_diagnostics_,
798+
csv_contents$post_warmup_sampler_diagnostics[,,missing_variables],
799+
along="variable"
800+
)
801+
}
799802
}
800803
if (!is.null(csv_contents$metadata$save_warmup)
801804
&& csv_contents$metadata$save_warmup) {
802805
if (!is.null(csv_contents$warmup_draws)) {
803-
missing_variables <- !(posterior::variables(csv_contents$warmup_draws) %in% posterior::variables(private$warmup_draws_))
804-
private$warmup_draws_ <- posterior::bind_draws(
805-
private$warmup_draws_,
806-
csv_contents$warmup_draws[,,missing_variables],
807-
along="variable"
808-
)
806+
if (is.null(private$warmup_draws_)) {
807+
private$warmup_draws_ <- csv_contents$warmup_draws
808+
} else {
809+
missing_variables <- !(posterior::variables(csv_contents$warmup_draws) %in% posterior::variables(private$warmup_draws_))
810+
private$warmup_draws_ <- posterior::bind_draws(
811+
private$warmup_draws_,
812+
csv_contents$warmup_draws[,,missing_variables],
813+
along="variable"
814+
)
815+
}
809816
}
810817
if (!is.null(csv_contents$warmup_sampler_diagnostics)) {
811-
missing_variables <- !(posterior::variables(csv_contents$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
812-
private$warmup_sampler_diagnostics_ <- posterior::bind_draws(
813-
private$warmup_sampler_diagnostics_,
814-
csv_contents$warmup_sampler_diagnostics[,,missing_variables],
815-
along="variable"
816-
)
818+
if (is.null(private$warmup_sampler_diagnostics_)) {
819+
private$warmup_sampler_diagnostics_ <- csv_contents$warmup_sampler_diagnostics
820+
} else {
821+
missing_variables <- !(posterior::variables(csv_contents$warmup_sampler_diagnostics) %in% posterior::variables(private$warmup_sampler_diagnostics_))
822+
private$warmup_sampler_diagnostics_ <- posterior::bind_draws(
823+
private$warmup_sampler_diagnostics_,
824+
csv_contents$warmup_sampler_diagnostics[,,missing_variables],
825+
along="variable"
826+
)
827+
}
817828
}
818829
}
819830
invisible(self)

R/utils.R

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,9 @@ set_num_threads <- function(num_threads) {
253253
call. = FALSE)
254254
}
255255

256-
check_divergences <- function(csv_contents) {
257-
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
258-
divergences <- posterior::extract_variable_matrix(csv_contents$post_warmup_sampler_diagnostics, "divergent__")
256+
check_divergences <- function(post_warmup_sampler_diagnostics) {
257+
if (!is.null(post_warmup_sampler_diagnostics)) {
258+
divergences <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "divergent__")
259259
num_of_draws <- length(divergences)
260260
num_of_divergences <- sum(divergences)
261261
if (!is.na(num_of_divergences) && num_of_divergences > 0) {
@@ -274,17 +274,16 @@ check_divergences <- function(csv_contents) {
274274
}
275275
}
276276

277-
check_sampler_transitions_treedepth <- function(csv_contents) {
278-
if (!is.null(csv_contents$post_warmup_sampler_diagnostics)) {
279-
treedepth <- posterior::extract_variable_matrix(csv_contents$post_warmup_sampler_diagnostics, "treedepth__")
277+
check_sampler_transitions_treedepth <- function(post_warmup_sampler_diagnostics, metadata) {
278+
if (!is.null(post_warmup_sampler_diagnostics)) {
279+
treedepth <- posterior::extract_variable_matrix(post_warmup_sampler_diagnostics, "treedepth__")
280280
num_of_draws <- length(treedepth)
281-
max_treedepth <- csv_contents$metadata$max_treedepth
282-
max_treedepth_hit <- sum(treedepth >= max_treedepth)
281+
max_treedepth_hit <- sum(treedepth >= metadata$max_treedepth)
283282
if (!is.na(max_treedepth_hit) && max_treedepth_hit > 0) {
284283
percentage_max_treedepth <- (max_treedepth_hit)/num_of_draws*100
285284
message(max_treedepth_hit, " of ", num_of_draws, " (", (format(round(percentage_max_treedepth, 0), nsmall = 1)), "%)",
286-
" transitions hit the maximum treedepth limit of ", max_treedepth,
287-
" or 2^", max_treedepth, "-1 leapfrog steps.\n",
285+
" transitions hit the maximum treedepth limit of ", metadata$max_treedepth,
286+
" or 2^", metadata$max_treedepth, "-1 leapfrog steps.\n",
288287
"Trajectories that are prematurely terminated due to this limit will result in slow exploration.\n",
289288
"Increasing the max_treedepth limit can avoid this at the expense of more computation.\n",
290289
"If increasing max_treedepth does not remove warnings, try to reparameterize the model.\n")

tests/testthat/test-fit-mcmc.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,29 @@ test_that("draws() stops for unkown variables", {
3838
)
3939
})
4040

41+
test_that("draws() works when gradually adding variables", {
42+
skip_on_cran()
43+
fit <- testing_fit("logistic", method = "sample", refresh = 0,
44+
validate_csv = TRUE, save_warmup = TRUE)
45+
46+
draws_lp__ <- fit$draws(variables = c("lp__"), inc_warmup = TRUE)
47+
sampler_diagnostics <- fit$sampler_diagnostics(inc_warmup = TRUE)
48+
expect_type(draws_lp__, "double")
49+
expect_s3_class(draws_lp__, "draws_array")
50+
expect_equal(posterior::variables(draws_lp__), c("lp__"))
51+
expect_type(sampler_diagnostics, "double")
52+
expect_s3_class(sampler_diagnostics, "draws_array")
53+
expect_equal(posterior::variables(sampler_diagnostics), c(c("treedepth__", "divergent__", "accept_stat__", "stepsize__", "n_leapfrog__", "energy__")))
54+
draws_alpha <- fit$draws(variables = c("alpha"), inc_warmup = TRUE)
55+
expect_type(draws_alpha, "double")
56+
expect_s3_class(draws_alpha, "draws_array")
57+
expect_equal(posterior::variables(draws_alpha), c("alpha"))
58+
draws_beta <- fit$draws(variables = c("beta"), inc_warmup = TRUE)
59+
expect_type(draws_beta, "double")
60+
expect_s3_class(draws_beta, "draws_array")
61+
expect_equal(posterior::variables(draws_beta), c("beta[1]", "beta[2]", "beta[3]"))
62+
})
63+
4164
test_that("draws() method returns draws_array (reading csv works)", {
4265
skip_on_cran()
4366
draws <- fit_mcmc$draws()
@@ -273,3 +296,4 @@ test_that("loo errors if it can't find log lik variables", {
273296
fixed = TRUE
274297
)
275298
})
299+

tests/testthat/test-utils.R

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,18 @@ test_that("check_divergences() works", {
1111
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"))
1212
csv_output <- read_cmdstan_csv(csv_files)
1313
output <- "14 of 100 \\(14.0%\\) transitions ended with a divergence."
14-
expect_message(check_divergences(csv_output), output)
14+
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)
1515

1616
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"),
1717
test_path("resources", "csv", "model1-2-no-warmup.csv"))
1818
csv_output <- read_cmdstan_csv(csv_files)
1919
output <- "28 of 200 \\(14.0%\\) transitions ended with a divergence."
20-
expect_message(check_divergences(csv_output), output)
20+
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)
2121

2222
csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv"))
2323
csv_output <- read_cmdstan_csv(csv_files)
2424
output <- "1 of 100 \\(1.0%\\) transitions ended with a divergence."
25-
expect_message(check_divergences(csv_output), output)
25+
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), output)
2626

2727

2828
fit_wramup_no_samples <- testing_fit("logistic", method = "sample",
@@ -32,27 +32,41 @@ test_that("check_divergences() works", {
3232
save_warmup = TRUE,
3333
validate_csv = FALSE)
3434
csv_output <- read_cmdstan_csv(fit_wramup_no_samples$output_files())
35-
expect_message(check_divergences(csv_output), regexp = NA)
35+
expect_message(check_divergences(csv_output$post_warmup_sampler_diagnostics), regexp = NA)
3636
})
3737

3838
test_that("check_sampler_transitions_treedepth() works", {
3939
skip_on_cran()
4040
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"))
4141
csv_output <- read_cmdstan_csv(csv_files)
4242
output <- "16 of 100 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
43-
expect_message(check_sampler_transitions_treedepth(csv_output), output)
43+
expect_message(
44+
check_sampler_transitions_treedepth(
45+
csv_output$post_warmup_sampler_diagnostics,
46+
csv_output$metadata),
47+
output
48+
)
4449

4550
csv_files <- c(test_path("resources", "csv", "model1-2-no-warmup.csv"),
4651
test_path("resources", "csv", "model1-2-no-warmup.csv"))
4752
csv_output <- read_cmdstan_csv(csv_files)
4853
output <- "32 of 200 \\(16.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
49-
expect_message(check_sampler_transitions_treedepth(csv_output), output)
54+
expect_message(
55+
check_sampler_transitions_treedepth(
56+
csv_output$post_warmup_sampler_diagnostics,
57+
csv_output$metadata),
58+
output
59+
)
5060

5161
csv_files <- c(test_path("resources", "csv", "model1-2-warmup.csv"))
5262
csv_output <- read_cmdstan_csv(csv_files)
5363
output <- "1 of 100 \\(1.0%\\) transitions hit the maximum treedepth limit of 5 or 2\\^5-1 leapfrog steps."
54-
expect_message(check_sampler_transitions_treedepth(csv_output), output)
55-
64+
expect_message(
65+
check_sampler_transitions_treedepth(
66+
csv_output$post_warmup_sampler_diagnostics,
67+
csv_output$metadata),
68+
output
69+
)
5670
})
5771

5872
test_that("cmdstan_summary works if bin/stansummary deleted file", {

0 commit comments

Comments
 (0)