@@ -121,8 +121,6 @@ read_cmdstan_csv <- function(files,
121
121
col_types <- NULL
122
122
col_select <- NULL
123
123
not_matching <- c()
124
- vroom_warnings <- 0
125
-
126
124
for (output_file in files ) {
127
125
if (is.null(metadata )) {
128
126
metadata <- read_csv_metadata(output_file )
@@ -205,48 +203,22 @@ read_cmdstan_csv <- function(files,
205
203
} else if (metadata $ method == " optimize" ) {
206
204
all_draws <- 1
207
205
}
208
-
209
- vroom_args <- list (
210
- file = output_file ,
211
- comment = " #" ,
212
- delim = " ," ,
213
- trim_ws = TRUE ,
214
- altrep = FALSE ,
215
- progress = FALSE ,
216
- skip = metadata $ lines_to_skip ,
217
- col_select = col_select ,
218
- num_threads = 1
219
- )
220
- if (metadata $ method == " generate_quantities" ) {
221
- # set the first arg as double to silence the type detection info
222
- vroom_args $ col_types <- list ()
223
- vroom_args $ col_types [[col_select [1 ]]] <- " d"
224
- } else {
225
- vroom_args $ col_types <- c(" lp__" = " d" )
226
- vroom_args $ n_max <- all_draws * 2
227
- }
228
-
229
- draws <- try(silent = TRUE , expr = {
230
- suppressWarnings(do.call(vroom :: vroom , vroom_args ))
231
- })
232
- if (! inherits(draws , " try-error" )) {
233
- if (metadata $ method != " generate_quantities" ) {
234
- draws <- draws [! is.na(draws $ lp__ ), ]
206
+ if (length(col_select ) > 0 ) {
207
+ if (os_is_windows()) {
208
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
209
+ fread_cmd <- paste0(grep_path , " -v '^#' " , output_file )
210
+ } else {
211
+ fread_cmd <- paste0(" grep -v '^#' " , output_file )
235
212
}
236
- } else {
237
- if (vroom_warnings == 0 ) { # only warn the first time instead of for every csv file
238
- warning(
239
- " Fast CSV reading with vroom::vroom() failed. Using utils::read.csv() instead. " ,
240
- " \n To help avoid this in the future, please report this issue at github.com/stan-dev/cmdstanr/issues " ,
241
- " and include the output from sessionInfo(). Thank you!" ,
242
- call. = FALSE
213
+ suppressWarnings(
214
+ draws <- data.table :: fread(
215
+ cmd = fread_cmd ,
216
+ select = col_select
243
217
)
244
- }
245
- vroom_warnings <- vroom_warnings + 1
246
- draws <- utils :: read.csv(output_file , comment.char = " #" , skip = metadata $ lines_to_skip )
247
- draws <- draws [, col_select ]
218
+ )
219
+ } else {
220
+ draws <- NULL
248
221
}
249
-
250
222
if (nrow(draws ) > 0 ) {
251
223
if (metadata $ method == " sample" ) {
252
224
if (metadata $ save_warmup == 1 ) {
@@ -316,7 +288,6 @@ read_cmdstan_csv <- function(files,
316
288
}
317
289
}
318
290
}
319
-
320
291
if (length(not_matching ) > 0 ) {
321
292
not_matching_list <- paste(unique(not_matching ), collapse = " , " )
322
293
warning(" Supplied CSV files do not match in the following arguments: " ,
@@ -411,102 +382,93 @@ read_sample_csv <- function(files,
411
382
# '
412
383
read_csv_metadata <- function (csv_file ) {
413
384
checkmate :: assert_file_exists(csv_file , access = " r" , extension = " csv" )
414
- con <- file(csv_file , open = " r" )
415
385
adaptation_terminated <- FALSE
416
386
param_names_read <- FALSE
417
387
inv_metric_next <- FALSE
418
388
inv_metric_diagonal_next <- FALSE
419
389
csv_file_info <- list ()
420
- csv_file_info [[" inv_metric" ]] <- NULL
421
390
inv_metric_rows <- 0
422
391
parsing_done <- FALSE
423
- lines_before_param_names <- 0
424
- while (length(line <- readLines(con , n = 1 , warn = FALSE )) > 0 && ! parsing_done ) {
425
- if (! startsWith(line , " #" )) {
426
- if (! param_names_read ) {
427
- param_names_read <- TRUE
428
- all_names <- strsplit(line , " ," )[[1 ]]
429
- csv_file_info [[" sampler_diagnostics" ]] <- c()
430
- csv_file_info [[" model_params" ]] <- c()
431
- for (x in all_names ) {
432
- if (all(csv_file_info $ algorithm != " fixed_param" )) {
433
- if (endsWith(x , " __" ) && ! (x %in% c(" lp__" , " log_p__" , " log_g__" ))) {
434
- csv_file_info [[" sampler_diagnostics" ]] <- c(csv_file_info [[" sampler_diagnostics" ]], x )
435
- } else {
436
- csv_file_info [[" model_params" ]] <- c(csv_file_info [[" model_params" ]], x )
437
- }
438
- } else {
439
- if (! endsWith(x , " __" )) {
440
- csv_file_info [[" model_params" ]] <- c(csv_file_info [[" model_params" ]], x )
441
- }
442
- }
443
- }
392
+ if (os_is_windows()) {
393
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
394
+ fread_cmd <- paste0(grep_path , " '^[#a-zA-Z]' " , csv_file )
395
+ } else {
396
+ fread_cmd <- paste0(" grep '^[#a-zA-Z]' " , csv_file )
397
+ }
398
+ suppressWarnings(
399
+ metadata <- data.table :: fread(
400
+ cmd = fread_cmd ,
401
+ colClasses = " character" ,
402
+ stringsAsFactors = FALSE ,
403
+ fill = TRUE ,
404
+ sep = " " ,
405
+ header = FALSE
406
+ )
407
+ )
408
+ if (is.null(metadata ) || length(metadata ) == 0 ) {
409
+ stop(" Supplied CSV file is corrupt!" , call. = FALSE )
410
+ }
411
+ for (line in metadata [[1 ]]) {
412
+ if (! startsWith(line , " #" ) && is.null(csv_file_info [[" model_params" ]])) {
413
+ # if no # at the start of line, the line is the CSV header
414
+ all_names <- strsplit(line , " ," )[[1 ]]
415
+ if (all(csv_file_info $ algorithm != " fixed_param" )) {
416
+ csv_file_info [[" sampler_diagnostics" ]] <- all_names [endsWith(all_names , " __" )]
417
+ csv_file_info [[" sampler_diagnostics" ]] <- csv_file_info [[" sampler_diagnostics" ]][! (csv_file_info [[" sampler_diagnostics" ]] %in% c(" lp__" , " log_p__" , " log_g__" ))]
418
+ csv_file_info [[" model_params" ]] <- all_names [! (all_names %in% csv_file_info [[" sampler_diagnostics" ]])]
419
+ } else {
420
+ csv_file_info [[" model_params" ]] <- all_names [! endsWith(all_names , " __" )]
444
421
}
445
422
} else {
446
- if (! param_names_read ) {
447
- lines_before_param_names <- lines_before_param_names + 1
448
- }
449
- if (! adaptation_terminated ) {
450
- if (regexpr(" # Adaptation terminated" , line , perl = TRUE ) > 0 ) {
451
- adaptation_terminated <- TRUE
452
- } else {
453
- tmp <- gsub(" #" , " " , line , fixed = TRUE )
454
- tmp <- gsub(" (Default)" , " " , tmp , fixed = TRUE )
455
- key_val <- grep(" =" , tmp , fixed = TRUE , value = TRUE )
456
- key_val <- strsplit(key_val , split = " =" , fixed = TRUE )
457
- key_val <- rapply(key_val , trimws )
458
- if (length(key_val ) == 2 ) {
459
- numeric_val <- suppressWarnings(as.numeric(key_val [2 ]))
460
- if (! is.na(numeric_val )) {
461
- csv_file_info [[key_val [1 ]]] <- numeric_val
462
- } else {
463
- if (nzchar(key_val [2 ])) {
464
- csv_file_info [[key_val [1 ]]] <- key_val [2 ]
465
- }
466
- }
467
- }
423
+ parse_key_val <- TRUE
424
+ if (regexpr(" # Diagonal elements of inverse mass matrix:" , line , perl = TRUE ) > 0
425
+ || regexpr(" # Elements of inverse mass matrix:" , line , perl = TRUE ) > 0 ) {
426
+ inv_metric_next <- TRUE
427
+ parse_key_val <- FALSE
428
+ } else if (inv_metric_next ) {
429
+ inv_metric_split <- strsplit(gsub(" # " , " " , line ), " ," )
430
+ if ((length(inv_metric_split ) == 0 ) ||
431
+ ((length(inv_metric_split ) == 1 ) && identical(inv_metric_split [[1 ]], character (0 ))) ||
432
+ regexpr(" [a-zA-z]" , line , perl = TRUE ) > 0 ||
433
+ inv_metric_split == " #" ) {
434
+ parsing_done <- TRUE
435
+ parse_key_val <- TRUE
436
+ break ;
468
437
}
469
- } else {
470
- # after adaptation terminated read in the step size and inverse metrics
471
- if (regexpr(" # Step size = " , line , perl = TRUE ) > 0 ) {
472
- csv_file_info $ stepsize_adaptation <- as.numeric(strsplit(line , " = " )[[1 ]][2 ])
473
- } else if (regexpr(" # Diagonal elements of inverse mass matrix:" , line , perl = TRUE ) > 0 ) {
474
- inv_metric_diagonal_next <- TRUE
475
- } else if (regexpr(" # Elements of inverse mass matrix:" , line , perl = TRUE ) > 0 ){
476
- inv_metric_next <- TRUE
477
- } else if (inv_metric_diagonal_next ) {
478
- inv_metric_split <- strsplit(gsub(" # " , " " , line ), " ," )
479
- if ((length(inv_metric_split ) == 0 ) ||
480
- ((length(inv_metric_split ) == 1 ) && identical(inv_metric_split [[1 ]], character (0 )))) {
481
- break ;
482
- }
438
+ if (inv_metric_rows == 0 ) {
483
439
csv_file_info $ inv_metric <- rapply(inv_metric_split , as.numeric )
484
- parsing_done <- TRUE
485
- } else if (inv_metric_next ) {
486
- inv_metric_split <- strsplit(gsub(" # " , " " , line ), " ," )
487
- if ((length(inv_metric_split ) == 0 ) ||
488
- ((length(inv_metric_split ) == 1 ) && identical(inv_metric_split [[1 ]], character (0 )))) {
489
- parsing_done <- TRUE
490
- break ;
491
- }
492
- if (inv_metric_rows == 0 ) {
493
- csv_file_info $ inv_metric <- rapply(inv_metric_split , as.numeric )
440
+ } else {
441
+ csv_file_info $ inv_metric <- c(csv_file_info $ inv_metric , rapply(inv_metric_split , as.numeric ))
442
+ }
443
+ inv_metric_rows <- inv_metric_rows + 1
444
+ parse_key_val <- FALSE
445
+ }
446
+ if (parse_key_val ) {
447
+ tmp <- gsub(" #" , " " , line , fixed = TRUE )
448
+ tmp <- gsub(" (Default)" , " " , tmp , fixed = TRUE )
449
+ key_val <- grep(" =" , tmp , fixed = TRUE , value = TRUE )
450
+ key_val <- strsplit(key_val , split = " =" , fixed = TRUE )
451
+ key_val <- rapply(key_val , trimws )
452
+ if (any(key_val [1 ] == " Step size" )) {
453
+ key_val [1 ] <- " step_size_adaptation"
454
+ }
455
+ if (length(key_val ) == 2 ) {
456
+ numeric_val <- suppressWarnings(as.numeric(key_val [2 ]))
457
+ if (! is.na(numeric_val )) {
458
+ csv_file_info [[key_val [1 ]]] <- numeric_val
494
459
} else {
495
- csv_file_info $ inv_metric <- c(csv_file_info $ inv_metric , rapply(inv_metric_split , as.numeric ))
460
+ if (nzchar(key_val [2 ])) {
461
+ csv_file_info [[key_val [1 ]]] <- key_val [2 ]
462
+ }
496
463
}
497
- inv_metric_rows <- inv_metric_rows + 1
498
464
}
499
465
}
500
466
}
501
467
}
502
- close(con )
503
- if (is.null(csv_file_info $ method )) {
504
- stop(" Supplied CSV file is corrupt!" , call. = FALSE )
505
- }
506
468
if (length(csv_file_info $ sampler_diagnostics ) == 0 && length(csv_file_info $ model_params ) == 0 ) {
507
469
stop(" Supplied CSV file does not contain any variable names or data!" , call. = FALSE )
508
470
}
509
- if (inv_metric_rows > 0 ) {
471
+ if (inv_metric_rows > 0 && csv_file_info $ metric == " dense_e " ) {
510
472
rows <- inv_metric_rows
511
473
cols <- length(csv_file_info $ inv_metric )/ inv_metric_rows
512
474
dim(csv_file_info $ inv_metric ) <- c(rows ,cols )
@@ -518,7 +480,6 @@ read_csv_metadata <- function(csv_file) {
518
480
csv_file_info $ adapt_delta <- csv_file_info $ delta
519
481
csv_file_info $ max_treedepth <- csv_file_info $ max_depth
520
482
csv_file_info $ step_size <- csv_file_info $ stepsize
521
- csv_file_info $ step_size_adaptation <- csv_file_info $ stepsize_adaptation
522
483
csv_file_info $ iter_warmup <- csv_file_info $ num_warmup
523
484
csv_file_info $ iter_sampling <- csv_file_info $ num_samples
524
485
csv_file_info $ threads_per_chain <- csv_file_info $ num_threads
@@ -527,14 +488,12 @@ read_csv_metadata <- function(csv_file) {
527
488
csv_file_info $ delta <- NULL
528
489
csv_file_info $ max_depth <- NULL
529
490
csv_file_info $ stepsize <- NULL
530
- csv_file_info $ stepsize_adaptation <- NULL
531
491
csv_file_info $ num_warmup <- NULL
532
492
csv_file_info $ num_samples <- NULL
533
493
csv_file_info $ file <- NULL
534
494
csv_file_info $ diagnostic_file <- NULL
535
495
csv_file_info $ metric_file <- NULL
536
496
csv_file_info $ num_threads <- NULL
537
- csv_file_info $ lines_to_skip <- lines_before_param_names
538
497
539
498
csv_file_info
540
499
}
0 commit comments