Skip to content

Replace vroom with data.table::fread #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Nov 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Imports:
posterior (>= 0.1.0),
processx,
R6 (>= 2.4.0),
vroom
data.table
Suggests:
bayesplot,
knitr,
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ specifying custom chain IDs. (#319)

* Added support for the `sig_figs` argument in CmdStan versions 2.25 and above. (#327)

* CSV reading is now faster thanks to `data.table::fread()`. (#318)

# cmdstanr 0.1.3

* New `$check_syntax()` method for CmdStanModel objects. (#276, #277)
Expand Down
19 changes: 11 additions & 8 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,22 @@ process_fitted_params <- function(fitted_params) {
paths <- file.path(tempdir(), paths)
chain <- 1
for (path in paths) {
chain_draws <- posterior::as_draws_df(posterior::subset_draws(draws, chain = chain))
colnames(chain_draws) <- unrepair_variable_names(variables)
chain_draws <- posterior::subset_draws(draws, chain = chain)
write(
paste0("# num_samples = ", iterations),
file = path
)
write(
paste0(unrepair_variable_names(variables), collapse = ","),
file = path,
append = FALSE
append = TRUE
)
vroom::vroom_write(
utils::write.table(
chain_draws,
delim = ",",
path = path,
col_names = TRUE,
progress = FALSE,
file = path,
sep = ",",
col.names = FALSE,
row.names = FALSE,
append = TRUE
)
chain <- chain + 1
Expand Down
201 changes: 80 additions & 121 deletions R/read_csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ read_cmdstan_csv <- function(files,
col_types <- NULL
col_select <- NULL
not_matching <- c()
vroom_warnings <- 0

for (output_file in files) {
if (is.null(metadata)) {
metadata <- read_csv_metadata(output_file)
Expand Down Expand Up @@ -205,48 +203,22 @@ read_cmdstan_csv <- function(files,
} else if (metadata$method == "optimize") {
all_draws <- 1
}

vroom_args <- list(
file = output_file,
comment = "#",
delim = ",",
trim_ws = TRUE,
altrep = FALSE,
progress = FALSE,
skip = metadata$lines_to_skip,
col_select = col_select,
num_threads = 1
)
if (metadata$method == "generate_quantities") {
# set the first arg as double to silence the type detection info
vroom_args$col_types <- list()
vroom_args$col_types[[col_select[1]]] <- "d"
} else {
vroom_args$col_types <- c("lp__" = "d")
vroom_args$n_max <- all_draws * 2
}

draws <- try(silent = TRUE, expr = {
suppressWarnings(do.call(vroom::vroom, vroom_args))
})
if (!inherits(draws, "try-error")) {
if (metadata$method != "generate_quantities") {
draws <- draws[!is.na(draws$lp__), ]
if (length(col_select) > 0) {
if (os_is_windows()) {
grep_path <- repair_path(Sys.which("grep.exe"))
fread_cmd <- paste0(grep_path, " -v '^#' ", output_file)
} else {
fread_cmd <- paste0("grep -v '^#' ", output_file)
}
} else {
if (vroom_warnings == 0) { # only warn the first time instead of for every csv file
warning(
"Fast CSV reading with vroom::vroom() failed. Using utils::read.csv() instead. ",
"\nTo help avoid this in the future, please report this issue at github.com/stan-dev/cmdstanr/issues ",
"and include the output from sessionInfo(). Thank you!",
call. = FALSE
suppressWarnings(
draws <- data.table::fread(
cmd = fread_cmd,
select = col_select
)
}
vroom_warnings <- vroom_warnings + 1
draws <- utils::read.csv(output_file, comment.char = "#", skip = metadata$lines_to_skip)
draws <- draws[, col_select]
)
} else {
draws <- NULL
}

if (nrow(draws) > 0) {
if (metadata$method == "sample") {
if (metadata$save_warmup == 1) {
Expand Down Expand Up @@ -316,7 +288,6 @@ read_cmdstan_csv <- function(files,
}
}
}

if (length(not_matching) > 0) {
not_matching_list <- paste(unique(not_matching), collapse = ", ")
warning("Supplied CSV files do not match in the following arguments: ",
Expand Down Expand Up @@ -411,102 +382,93 @@ read_sample_csv <- function(files,
#'
read_csv_metadata <- function(csv_file) {
checkmate::assert_file_exists(csv_file, access = "r", extension = "csv")
con <- file(csv_file, open = "r")
adaptation_terminated <- FALSE
param_names_read <- FALSE
inv_metric_next <- FALSE
inv_metric_diagonal_next <- FALSE
csv_file_info <- list()
csv_file_info[["inv_metric"]] <- NULL
inv_metric_rows <- 0
parsing_done <- FALSE
lines_before_param_names <- 0
while (length(line <- readLines(con, n = 1, warn = FALSE)) > 0 && !parsing_done) {
if (!startsWith(line, "#")) {
if (!param_names_read) {
param_names_read <- TRUE
all_names <- strsplit(line, ",")[[1]]
csv_file_info[["sampler_diagnostics"]] <- c()
csv_file_info[["model_params"]] <- c()
for (x in all_names) {
if (all(csv_file_info$algorithm != "fixed_param")) {
if (endsWith(x, "__") && !(x %in% c("lp__", "log_p__", "log_g__"))) {
csv_file_info[["sampler_diagnostics"]] <- c(csv_file_info[["sampler_diagnostics"]], x)
} else {
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
} else {
if (!endsWith(x, "__")) {
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
}
}
if (os_is_windows()) {
grep_path <- repair_path(Sys.which("grep.exe"))
fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' ", csv_file)
} else {
fread_cmd <- paste0("grep '^[#a-zA-Z]' ", csv_file)
}
suppressWarnings(
metadata <- data.table::fread(
cmd = fread_cmd,
colClasses = "character",
stringsAsFactors = FALSE,
fill = TRUE,
sep = "",
header= FALSE
)
)
if (is.null(metadata) || length(metadata) == 0) {
stop("Supplied CSV file is corrupt!", call. = FALSE)
}
for (line in metadata[[1]]) {
if (!startsWith(line, "#") && is.null(csv_file_info[["model_params"]])) {
# if no # at the start of line, the line is the CSV header
all_names <- strsplit(line, ",")[[1]]
if (all(csv_file_info$algorithm != "fixed_param")) {
csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))]
csv_file_info[["model_params"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
} else {
csv_file_info[["model_params"]] <- all_names[!endsWith(all_names, "__")]
}
} else {
if (!param_names_read) {
lines_before_param_names <- lines_before_param_names + 1
}
if (!adaptation_terminated) {
if (regexpr("# Adaptation terminated", line, perl = TRUE) > 0) {
adaptation_terminated <- TRUE
} else {
tmp <- gsub("#", "", line, fixed = TRUE)
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
key_val <- rapply(key_val, trimws)
if (length(key_val) == 2) {
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
if (!is.na(numeric_val)) {
csv_file_info[[key_val[1]]] <- numeric_val
} else {
if (nzchar(key_val[2])) {
csv_file_info[[key_val[1]]] <- key_val[2]
}
}
}
parse_key_val <- TRUE
if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0
|| regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0) {
inv_metric_next <- TRUE
parse_key_val <- FALSE
} else if (inv_metric_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0))) ||
regexpr("[a-zA-z]", line, perl = TRUE) > 0 ||
inv_metric_split == "#") {
parsing_done <- TRUE
parse_key_val <- TRUE
break;
}
} else {
# after adaptation terminated read in the step size and inverse metrics
if (regexpr("# Step size = ", line, perl = TRUE) > 0) {
csv_file_info$stepsize_adaptation <- as.numeric(strsplit(line, " = ")[[1]][2])
} else if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0) {
inv_metric_diagonal_next <- TRUE
} else if (regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0){
inv_metric_next <- TRUE
} else if (inv_metric_diagonal_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
break;
}
if (inv_metric_rows == 0) {
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
parsing_done <- TRUE
} else if (inv_metric_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
parsing_done <- TRUE
break;
}
if (inv_metric_rows == 0) {
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
} else {
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
}
inv_metric_rows <- inv_metric_rows + 1
parse_key_val <- FALSE
}
if (parse_key_val) {
tmp <- gsub("#", "", line, fixed = TRUE)
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
key_val <- rapply(key_val, trimws)
if (any(key_val[1] == "Step size")) {
key_val[1] <- "step_size_adaptation"
}
if (length(key_val) == 2) {
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
if (!is.na(numeric_val)) {
csv_file_info[[key_val[1]]] <- numeric_val
} else {
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
if (nzchar(key_val[2])) {
csv_file_info[[key_val[1]]] <- key_val[2]
}
}
inv_metric_rows <- inv_metric_rows + 1
}
}
}
}
close(con)
if (is.null(csv_file_info$method)) {
stop("Supplied CSV file is corrupt!", call. = FALSE)
}
if (length(csv_file_info$sampler_diagnostics) == 0 && length(csv_file_info$model_params) == 0) {
stop("Supplied CSV file does not contain any variable names or data!", call. = FALSE)
}
if (inv_metric_rows > 0) {
if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") {
rows <- inv_metric_rows
cols <- length(csv_file_info$inv_metric)/inv_metric_rows
dim(csv_file_info$inv_metric) <- c(rows,cols)
Expand All @@ -518,7 +480,6 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$adapt_delta <- csv_file_info$delta
csv_file_info$max_treedepth <- csv_file_info$max_depth
csv_file_info$step_size <- csv_file_info$stepsize
csv_file_info$step_size_adaptation <- csv_file_info$stepsize_adaptation
csv_file_info$iter_warmup <- csv_file_info$num_warmup
csv_file_info$iter_sampling <- csv_file_info$num_samples
csv_file_info$threads_per_chain <- csv_file_info$num_threads
Expand All @@ -527,14 +488,12 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$delta <- NULL
csv_file_info$max_depth <- NULL
csv_file_info$stepsize <- NULL
csv_file_info$stepsize_adaptation <- NULL
csv_file_info$num_warmup <- NULL
csv_file_info$num_samples <- NULL
csv_file_info$file <- NULL
csv_file_info$diagnostic_file <- NULL
csv_file_info$metric_file <- NULL
csv_file_info$num_threads <- NULL
csv_file_info$lines_to_skip <- lines_before_param_names

csv_file_info
}
Expand Down
17 changes: 9 additions & 8 deletions tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ test_that("process_fitted_params() works if output_files in fit do not exist", {
expect_true(all(file.exists(new_files)))
chain <- 1
for(file in new_files) {
if (os_is_windows()) {
grep_path <- repair_path(Sys.which("grep.exe"))
fread_cmd <- paste0(grep_path, " -v '^#' ", file)
} else {
fread_cmd <- paste0("grep -v '^#' ", file)
}
suppressWarnings(
tmp_file_gq <- vroom::vroom(
file,
comment = "#",
delim = ',',
trim_ws = TRUE,
altrep = FALSE,
progress = FALSE,
skip = 1)
tmp_file_gq <- data.table::fread(
cmd = fread_cmd
)
)
tmp_file_gq <- posterior::as_draws_array(tmp_file_gq)
expect_equal(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-fit-gq.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_that("draws() method returns draws_array (reading csv works)", {
draws_sum_y <- fit_gq$draws(variables = c("sum_y", "y_rep"))
draws_y_sum <- fit_gq$draws(variables = c("y_rep", "sum_y"))
draws_all_after <- fit_gq$draws()
expect_type(draws, "double")
expect_type(draws, "integer")
expect_s3_class(draws, "draws_array")
expect_equal(posterior::variables(draws), PARAM_NAMES)
expect_equal(posterior::nchains(draws), fit_gq$num_chains())
Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,11 @@ test_that("draws() method returns a 'draws' object", {
for (method in all_methods) {
fit <- fits[[method]]
draws <- fit$draws()
expect_type(draws, "double")
if (method == "generate_quantities") {
expect_type(draws, "integer")
} else {
expect_type(draws, "double")
}
expect_s3_class(draws, "draws")
}
})
Expand Down