@@ -120,11 +120,11 @@ read_cmdstan_csv <- function(files,
120
120
sampler_diagnostics = NULL ) {
121
121
checkmate :: assert_file_exists(files , access = " r" , extension = " csv" )
122
122
metadata <- NULL
123
- warmup_draws <- NULL
124
- warmup_sampler_diagnostics_draws <- NULL
125
- post_warmup_draws <- NULL
126
- post_warmup_sampler_diagnostics_draws <- NULL
127
- generated_quantities <- NULL
123
+ warmup_draws <- list ()
124
+ post_warmup_draws <- list ()
125
+ warmup_sampler_diagnostics_draws <- list ()
126
+ post_warmup_sampler_diagnostics_draws <- list ()
127
+ generated_quantities <- list ()
128
128
variational_draws <- NULL
129
129
point_estimates <- NULL
130
130
inv_metric <- list ()
@@ -241,49 +241,25 @@ read_cmdstan_csv <- function(files,
241
241
if (metadata $ method == " sample" ) {
242
242
if (metadata $ save_warmup == 1 ) {
243
243
if (length(variables ) > 0 ) {
244
- warmup_draws <- posterior :: bind_draws(
245
- warmup_draws ,
246
- posterior :: as_draws_array(draws [1 : num_warmup_draws , variables , drop = FALSE ]),
247
- along = " chain"
248
- )
244
+ warmup_draws [[length(warmup_draws ) + 1 ]] <- draws [1 : num_warmup_draws , variables , drop = FALSE ]
249
245
if (num_post_warmup_draws > 0 ) {
250
- post_warmup_draws <- posterior :: bind_draws(
251
- post_warmup_draws ,
252
- posterior :: as_draws_array(draws [(num_warmup_draws + 1 ): all_draws , variables , drop = FALSE ]),
253
- along = " chain"
254
- )
246
+ post_warmup_draws [[length(post_warmup_draws ) + 1 ]] <- draws [(num_warmup_draws + 1 ): all_draws , variables , drop = FALSE ]
255
247
}
256
248
}
257
249
if (length(sampler_diagnostics ) > 0 ) {
258
- warmup_sampler_diagnostics_draws <- posterior :: bind_draws(
259
- warmup_sampler_diagnostics_draws ,
260
- posterior :: as_draws_array(draws [1 : num_warmup_draws , sampler_diagnostics , drop = FALSE ]),
261
- along = " chain"
262
- )
250
+ warmup_sampler_diagnostics_draws [[length(warmup_sampler_diagnostics_draws ) + 1 ]] <- draws [1 : num_warmup_draws , sampler_diagnostics , drop = FALSE ]
263
251
if (num_post_warmup_draws > 0 ) {
264
- post_warmup_sampler_diagnostics_draws <- posterior :: bind_draws(
265
- post_warmup_sampler_diagnostics_draws ,
266
- posterior :: as_draws_array(draws [(num_warmup_draws + 1 ): all_draws , sampler_diagnostics , drop = FALSE ]),
267
- along = " chain"
268
- )
252
+ post_warmup_sampler_diagnostics_draws [[length(post_warmup_sampler_diagnostics_draws ) + 1 ]] <- draws [(num_warmup_draws + 1 ): all_draws , sampler_diagnostics , drop = FALSE ]
269
253
}
270
254
}
271
255
} else {
272
256
warmup_draws <- NULL
273
257
warmup_sampler_diagnostics_draws <- NULL
274
258
if (length(variables ) > 0 ) {
275
- post_warmup_draws <- posterior :: bind_draws(
276
- post_warmup_draws ,
277
- posterior :: as_draws_array(draws [, variables , drop = FALSE ]),
278
- along = " chain"
279
- )
259
+ post_warmup_draws [[length(post_warmup_draws ) + 1 ]] <- draws [, variables , drop = FALSE ]
280
260
}
281
261
if (length(sampler_diagnostics ) > 0 && all(metadata $ algorithm != " fixed_param" )) {
282
- post_warmup_sampler_diagnostics_draws <- posterior :: bind_draws(
283
- post_warmup_sampler_diagnostics_draws ,
284
- posterior :: as_draws_array(draws [, sampler_diagnostics , drop = FALSE ]),
285
- along = " chain"
286
- )
262
+ post_warmup_sampler_diagnostics_draws [[length(post_warmup_sampler_diagnostics_draws ) + 1 ]] <- draws [, sampler_diagnostics , drop = FALSE ]
287
263
}
288
264
}
289
265
} else if (metadata $ method == " variational" ) {
@@ -300,9 +276,7 @@ read_cmdstan_csv <- function(files,
300
276
} else if (metadata $ method == " optimize" ) {
301
277
point_estimates <- posterior :: as_draws_matrix(draws [1 ,, drop = FALSE ])[, variables ]
302
278
} else if (metadata $ method == " generate_quantities" ) {
303
- generated_quantities <- posterior :: bind_draws(generated_quantities ,
304
- posterior :: as_draws_array(draws ),
305
- along = " chain" )
279
+ generated_quantities [[length(generated_quantities ) + 1 ]] <- draws
306
280
}
307
281
}
308
282
}
@@ -313,7 +287,6 @@ read_cmdstan_csv <- function(files,
313
287
}
314
288
315
289
metadata $ inv_metric <- NULL
316
- metadata $ lines_to_skip <- NULL
317
290
metadata $ model_params <- repair_variable_names(metadata $ model_params )
318
291
repaired_variables <- repair_variable_names(variables )
319
292
if (metadata $ method == " variational" ) {
@@ -330,12 +303,16 @@ read_cmdstan_csv <- function(files,
330
303
metadata $ stan_variables <- names(model_param_dims )
331
304
332
305
if (metadata $ method == " sample" ) {
306
+ warmup_draws <- bind_list_of_draws_array(warmup_draws )
333
307
if (! is.null(warmup_draws )) {
334
308
posterior :: variables(warmup_draws ) <- repaired_variables
335
309
}
310
+ post_warmup_draws <- bind_list_of_draws_array(post_warmup_draws )
336
311
if (! is.null(post_warmup_draws )) {
337
312
posterior :: variables(post_warmup_draws ) <- repaired_variables
338
313
}
314
+ warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(warmup_sampler_diagnostics_draws )
315
+ post_warmup_sampler_diagnostics_draws <- bind_list_of_draws_array(post_warmup_sampler_diagnostics_draws )
339
316
list (
340
317
metadata = metadata ,
341
318
time = list (total = NA_integer_ , chains = time ),
@@ -363,6 +340,7 @@ read_cmdstan_csv <- function(files,
363
340
point_estimates = point_estimates
364
341
)
365
342
} else if (metadata $ method == " generate_quantities" ) {
343
+ generated_quantities <- bind_list_of_draws_array(generated_quantities )
366
344
if (! is.null(generated_quantities )) {
367
345
posterior :: variables(generated_quantities ) <- repaired_variables
368
346
}
@@ -422,8 +400,8 @@ CmdStanMCMC_CSV <- R6::R6Class(
422
400
public = list (
423
401
initialize = function (csv_contents , files , check_diagnostics = TRUE ) {
424
402
if (check_diagnostics ) {
425
- check_divergences(csv_contents )
426
- check_sampler_transitions_treedepth(csv_contents )
403
+ check_divergences(csv_contents $ post_warmup_sampler_diagnostics )
404
+ check_sampler_transitions_treedepth(csv_contents $ post_warmup_sampler_diagnostics , csv_contents $ metadata )
427
405
}
428
406
private $ output_files_ <- files
429
407
private $ metadata_ <- csv_contents $ metadata
@@ -708,7 +686,20 @@ check_csv_metadata_matches <- function(a, b) {
708
686
list (not_matching = not_matching )
709
687
}
710
688
711
-
689
+ bind_list_of_draws_array <- function (draws , along = " chain" ) {
690
+ if (! is.null(draws ) && length(draws ) > 0 ) {
691
+ if (length(draws ) > 1 ) {
692
+ draws <- lapply(draws , posterior :: as_draws_array )
693
+ draws [[" along" ]] <- along
694
+ draws <- do.call(posterior :: bind_draws , draws )
695
+ } else {
696
+ draws <- posterior :: as_draws_array(draws [[1 ]])
697
+ }
698
+ } else {
699
+ draws <- NULL
700
+ }
701
+ draws
702
+ }
712
703
713
704
# convert names like beta.1.1 to beta[1,1]
714
705
repair_variable_names <- function (names ) {
0 commit comments