Skip to content

Commit baed002

Browse files
authored
Merge pull request #318 from stan-dev/fread
Replace vroom with data.table::fread
2 parents 962e7c1 + e1d1c60 commit baed002

File tree

7 files changed

+109
-140
lines changed

7 files changed

+109
-140
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Imports:
3030
posterior (>= 0.1.0),
3131
processx,
3232
R6 (>= 2.4.0),
33-
vroom
33+
data.table
3434
Suggests:
3535
bayesplot,
3636
knitr,

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ specifying custom chain IDs. (#319)
3535

3636
* Added support for the `sig_figs` argument in CmdStan versions 2.25 and above. (#327)
3737

38+
* CSV reading is now faster because `vroom` has been replaced with `data.table::fread()`. (#318)
39+
3840
# cmdstanr 0.1.3
3941

4042
* New `$check_syntax()` method for CmdStanModel objects. (#276, #277)

R/data.R

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,22 @@ process_fitted_params <- function(fitted_params) {
162162
paths <- file.path(tempdir(), paths)
163163
chain <- 1
164164
for (path in paths) {
165-
chain_draws <- posterior::as_draws_df(posterior::subset_draws(draws, chain = chain))
166-
colnames(chain_draws) <- unrepair_variable_names(variables)
165+
chain_draws <- posterior::subset_draws(draws, chain = chain)
167166
write(
168167
paste0("# num_samples = ", iterations),
168+
file = path
169+
)
170+
write(
171+
paste0(unrepair_variable_names(variables), collapse = ","),
169172
file = path,
170-
append = FALSE
173+
append = TRUE
171174
)
172-
vroom::vroom_write(
175+
utils::write.table(
173176
chain_draws,
174-
delim = ",",
175-
path = path,
176-
col_names = TRUE,
177-
progress = FALSE,
177+
file = path,
178+
sep = ",",
179+
col.names = FALSE,
180+
row.names = FALSE,
178181
append = TRUE
179182
)
180183
chain <- chain + 1

R/read_csv.R

Lines changed: 80 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ read_cmdstan_csv <- function(files,
121121
col_types <- NULL
122122
col_select <- NULL
123123
not_matching <- c()
124-
vroom_warnings <- 0
125-
126124
for (output_file in files) {
127125
if (is.null(metadata)) {
128126
metadata <- read_csv_metadata(output_file)
@@ -205,48 +203,22 @@ read_cmdstan_csv <- function(files,
205203
} else if (metadata$method == "optimize") {
206204
all_draws <- 1
207205
}
208-
209-
vroom_args <- list(
210-
file = output_file,
211-
comment = "#",
212-
delim = ",",
213-
trim_ws = TRUE,
214-
altrep = FALSE,
215-
progress = FALSE,
216-
skip = metadata$lines_to_skip,
217-
col_select = col_select,
218-
num_threads = 1
219-
)
220-
if (metadata$method == "generate_quantities") {
221-
# set the first arg as double to silence the type detection info
222-
vroom_args$col_types <- list()
223-
vroom_args$col_types[[col_select[1]]] <- "d"
224-
} else {
225-
vroom_args$col_types <- c("lp__" = "d")
226-
vroom_args$n_max <- all_draws * 2
227-
}
228-
229-
draws <- try(silent = TRUE, expr = {
230-
suppressWarnings(do.call(vroom::vroom, vroom_args))
231-
})
232-
if (!inherits(draws, "try-error")) {
233-
if (metadata$method != "generate_quantities") {
234-
draws <- draws[!is.na(draws$lp__), ]
206+
if (length(col_select) > 0) {
207+
if (os_is_windows()) {
208+
grep_path <- repair_path(Sys.which("grep.exe"))
209+
fread_cmd <- paste0(grep_path, " -v '^#' ", output_file)
210+
} else {
211+
fread_cmd <- paste0("grep -v '^#' ", output_file)
235212
}
236-
} else {
237-
if (vroom_warnings == 0) { # only warn the first time instead of for every csv file
238-
warning(
239-
"Fast CSV reading with vroom::vroom() failed. Using utils::read.csv() instead. ",
240-
"\nTo help avoid this in the future, please report this issue at github.com/stan-dev/cmdstanr/issues ",
241-
"and include the output from sessionInfo(). Thank you!",
242-
call. = FALSE
213+
suppressWarnings(
214+
draws <- data.table::fread(
215+
cmd = fread_cmd,
216+
select = col_select
243217
)
244-
}
245-
vroom_warnings <- vroom_warnings + 1
246-
draws <- utils::read.csv(output_file, comment.char = "#", skip = metadata$lines_to_skip)
247-
draws <- draws[, col_select]
218+
)
219+
} else {
220+
draws <- NULL
248221
}
249-
250222
if (nrow(draws) > 0) {
251223
if (metadata$method == "sample") {
252224
if (metadata$save_warmup == 1) {
@@ -316,7 +288,6 @@ read_cmdstan_csv <- function(files,
316288
}
317289
}
318290
}
319-
320291
if (length(not_matching) > 0) {
321292
not_matching_list <- paste(unique(not_matching), collapse = ", ")
322293
warning("Supplied CSV files do not match in the following arguments: ",
@@ -411,102 +382,93 @@ read_sample_csv <- function(files,
411382
#'
412383
read_csv_metadata <- function(csv_file) {
413384
checkmate::assert_file_exists(csv_file, access = "r", extension = "csv")
414-
con <- file(csv_file, open = "r")
415385
adaptation_terminated <- FALSE
416386
param_names_read <- FALSE
417387
inv_metric_next <- FALSE
418388
inv_metric_diagonal_next <- FALSE
419389
csv_file_info <- list()
420-
csv_file_info[["inv_metric"]] <- NULL
421390
inv_metric_rows <- 0
422391
parsing_done <- FALSE
423-
lines_before_param_names <- 0
424-
while (length(line <- readLines(con, n = 1, warn = FALSE)) > 0 && !parsing_done) {
425-
if (!startsWith(line, "#")) {
426-
if (!param_names_read) {
427-
param_names_read <- TRUE
428-
all_names <- strsplit(line, ",")[[1]]
429-
csv_file_info[["sampler_diagnostics"]] <- c()
430-
csv_file_info[["model_params"]] <- c()
431-
for (x in all_names) {
432-
if (all(csv_file_info$algorithm != "fixed_param")) {
433-
if (endsWith(x, "__") && !(x %in% c("lp__", "log_p__", "log_g__"))) {
434-
csv_file_info[["sampler_diagnostics"]] <- c(csv_file_info[["sampler_diagnostics"]], x)
435-
} else {
436-
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
437-
}
438-
} else {
439-
if (!endsWith(x, "__")) {
440-
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
441-
}
442-
}
443-
}
392+
if (os_is_windows()) {
393+
grep_path <- repair_path(Sys.which("grep.exe"))
394+
fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' ", csv_file)
395+
} else {
396+
fread_cmd <- paste0("grep '^[#a-zA-Z]' ", csv_file)
397+
}
398+
suppressWarnings(
399+
metadata <- data.table::fread(
400+
cmd = fread_cmd,
401+
colClasses = "character",
402+
stringsAsFactors = FALSE,
403+
fill = TRUE,
404+
sep = "",
405+
header= FALSE
406+
)
407+
)
408+
if (is.null(metadata) || length(metadata) == 0) {
409+
stop("Supplied CSV file is corrupt!", call. = FALSE)
410+
}
411+
for (line in metadata[[1]]) {
412+
if (!startsWith(line, "#") && is.null(csv_file_info[["model_params"]])) {
413+
# a line that does not start with "#" is the CSV header (the column names)
414+
all_names <- strsplit(line, ",")[[1]]
415+
if (all(csv_file_info$algorithm != "fixed_param")) {
416+
csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
417+
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))]
418+
csv_file_info[["model_params"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
419+
} else {
420+
csv_file_info[["model_params"]] <- all_names[!endsWith(all_names, "__")]
444421
}
445422
} else {
446-
if (!param_names_read) {
447-
lines_before_param_names <- lines_before_param_names + 1
448-
}
449-
if (!adaptation_terminated) {
450-
if (regexpr("# Adaptation terminated", line, perl = TRUE) > 0) {
451-
adaptation_terminated <- TRUE
452-
} else {
453-
tmp <- gsub("#", "", line, fixed = TRUE)
454-
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
455-
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
456-
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
457-
key_val <- rapply(key_val, trimws)
458-
if (length(key_val) == 2) {
459-
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
460-
if (!is.na(numeric_val)) {
461-
csv_file_info[[key_val[1]]] <- numeric_val
462-
} else {
463-
if (nzchar(key_val[2])) {
464-
csv_file_info[[key_val[1]]] <- key_val[2]
465-
}
466-
}
467-
}
423+
parse_key_val <- TRUE
424+
if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0
425+
|| regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0) {
426+
inv_metric_next <- TRUE
427+
parse_key_val <- FALSE
428+
} else if (inv_metric_next) {
429+
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
430+
if ((length(inv_metric_split) == 0) ||
431+
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0))) ||
432+
regexpr("[a-zA-z]", line, perl = TRUE) > 0 ||
433+
inv_metric_split == "#") {
434+
parsing_done <- TRUE
435+
parse_key_val <- TRUE
436+
break;
468437
}
469-
} else {
470-
# after adaptation terminated read in the step size and inverse metrics
471-
if (regexpr("# Step size = ", line, perl = TRUE) > 0) {
472-
csv_file_info$stepsize_adaptation <- as.numeric(strsplit(line, " = ")[[1]][2])
473-
} else if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0) {
474-
inv_metric_diagonal_next <- TRUE
475-
} else if (regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0){
476-
inv_metric_next <- TRUE
477-
} else if (inv_metric_diagonal_next) {
478-
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
479-
if ((length(inv_metric_split) == 0) ||
480-
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
481-
break;
482-
}
438+
if (inv_metric_rows == 0) {
483439
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
484-
parsing_done <- TRUE
485-
} else if (inv_metric_next) {
486-
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
487-
if ((length(inv_metric_split) == 0) ||
488-
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
489-
parsing_done <- TRUE
490-
break;
491-
}
492-
if (inv_metric_rows == 0) {
493-
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
440+
} else {
441+
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
442+
}
443+
inv_metric_rows <- inv_metric_rows + 1
444+
parse_key_val <- FALSE
445+
}
446+
if (parse_key_val) {
447+
tmp <- gsub("#", "", line, fixed = TRUE)
448+
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
449+
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
450+
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
451+
key_val <- rapply(key_val, trimws)
452+
if (any(key_val[1] == "Step size")) {
453+
key_val[1] <- "step_size_adaptation"
454+
}
455+
if (length(key_val) == 2) {
456+
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
457+
if (!is.na(numeric_val)) {
458+
csv_file_info[[key_val[1]]] <- numeric_val
494459
} else {
495-
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
460+
if (nzchar(key_val[2])) {
461+
csv_file_info[[key_val[1]]] <- key_val[2]
462+
}
496463
}
497-
inv_metric_rows <- inv_metric_rows + 1
498464
}
499465
}
500466
}
501467
}
502-
close(con)
503-
if (is.null(csv_file_info$method)) {
504-
stop("Supplied CSV file is corrupt!", call. = FALSE)
505-
}
506468
if (length(csv_file_info$sampler_diagnostics) == 0 && length(csv_file_info$model_params) == 0) {
507469
stop("Supplied CSV file does not contain any variable names or data!", call. = FALSE)
508470
}
509-
if (inv_metric_rows > 0) {
471+
if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") {
510472
rows <- inv_metric_rows
511473
cols <- length(csv_file_info$inv_metric)/inv_metric_rows
512474
dim(csv_file_info$inv_metric) <- c(rows,cols)
@@ -518,7 +480,6 @@ read_csv_metadata <- function(csv_file) {
518480
csv_file_info$adapt_delta <- csv_file_info$delta
519481
csv_file_info$max_treedepth <- csv_file_info$max_depth
520482
csv_file_info$step_size <- csv_file_info$stepsize
521-
csv_file_info$step_size_adaptation <- csv_file_info$stepsize_adaptation
522483
csv_file_info$iter_warmup <- csv_file_info$num_warmup
523484
csv_file_info$iter_sampling <- csv_file_info$num_samples
524485
csv_file_info$threads_per_chain <- csv_file_info$num_threads
@@ -527,14 +488,12 @@ read_csv_metadata <- function(csv_file) {
527488
csv_file_info$delta <- NULL
528489
csv_file_info$max_depth <- NULL
529490
csv_file_info$stepsize <- NULL
530-
csv_file_info$stepsize_adaptation <- NULL
531491
csv_file_info$num_warmup <- NULL
532492
csv_file_info$num_samples <- NULL
533493
csv_file_info$file <- NULL
534494
csv_file_info$diagnostic_file <- NULL
535495
csv_file_info$metric_file <- NULL
536496
csv_file_info$num_threads <- NULL
537-
csv_file_info$lines_to_skip <- lines_before_param_names
538497

539498
csv_file_info
540499
}

tests/testthat/test-data.R

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,16 @@ test_that("process_fitted_params() works if output_files in fit do not exist", {
8080
expect_true(all(file.exists(new_files)))
8181
chain <- 1
8282
for(file in new_files) {
83+
if (os_is_windows()) {
84+
grep_path <- repair_path(Sys.which("grep.exe"))
85+
fread_cmd <- paste0(grep_path, " -v '^#' ", file)
86+
} else {
87+
fread_cmd <- paste0("grep -v '^#' ", file)
88+
}
8389
suppressWarnings(
84-
tmp_file_gq <- vroom::vroom(
85-
file,
86-
comment = "#",
87-
delim = ',',
88-
trim_ws = TRUE,
89-
altrep = FALSE,
90-
progress = FALSE,
91-
skip = 1)
90+
tmp_file_gq <- data.table::fread(
91+
cmd = fread_cmd
92+
)
9293
)
9394
tmp_file_gq <- posterior::as_draws_array(tmp_file_gq)
9495
expect_equal(

tests/testthat/test-fit-gq.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ test_that("draws() method returns draws_array (reading csv works)", {
3131
draws_sum_y <- fit_gq$draws(variables = c("sum_y", "y_rep"))
3232
draws_y_sum <- fit_gq$draws(variables = c("y_rep", "sum_y"))
3333
draws_all_after <- fit_gq$draws()
34-
expect_type(draws, "double")
34+
expect_type(draws, "integer")
3535
expect_s3_class(draws, "draws_array")
3636
expect_equal(posterior::variables(draws), PARAM_NAMES)
3737
expect_equal(posterior::nchains(draws), fit_gq$num_chains())

tests/testthat/test-fit-shared.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,11 @@ test_that("draws() method returns a 'draws' object", {
138138
for (method in all_methods) {
139139
fit <- fits[[method]]
140140
draws <- fit$draws()
141-
expect_type(draws, "double")
141+
if (method == "generate_quantities") {
142+
expect_type(draws, "integer")
143+
} else {
144+
expect_type(draws, "double")
145+
}
142146
expect_s3_class(draws, "draws")
143147
}
144148
})

0 commit comments

Comments
 (0)