Skip to content

Replace vroom with data.table::fread #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Nov 12, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Imports:
posterior (>= 0.1.0),
processx,
R6 (>= 2.4.0),
vroom
data.table
Suggests:
bayesplot,
knitr,
Expand Down
19 changes: 11 additions & 8 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,19 +162,22 @@ process_fitted_params <- function(fitted_params) {
paths <- file.path(tempdir(), paths)
chain <- 1
for (path in paths) {
chain_draws <- posterior::as_draws_df(posterior::subset_draws(draws, chain = chain))
colnames(chain_draws) <- unrepair_variable_names(variables)
chain_draws <- posterior::subset_draws(draws, chain = chain)
unname(chain_draws)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things:

  1. Why do we need to remove the names here? That's fine if necessary, just curious.

  2. Right now this isn't assigned to anything. I think you need

chain_draws <- unname(chain_draws)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to remove the names here? That's fine if necessary, just curious.

Because otherwise it writes the iteration IDs to the CSV.

Right now this isn't assigned to anything. I think you need

Hm, this did help with that, will double-check.

write(
paste0("# num_samples = ", iterations),
file = path
)
write(
paste0(unrepair_variable_names(variables), collapse = ","),
file = path,
append = FALSE
append = TRUE
)
vroom::vroom_write(
utils::write.table(
chain_draws,
delim = ",",
path = path,
col_names = TRUE,
progress = FALSE,
file = path,
sep = ",",
col.names = FALSE,
append = TRUE
)
chain <- chain + 1
Expand Down
188 changes: 70 additions & 118 deletions R/read_csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ read_cmdstan_csv <- function(files,
col_types <- NULL
col_select <- NULL
not_matching <- c()
vroom_warnings <- 0

for (output_file in files) {
if (is.null(metadata)) {
metadata <- read_csv_metadata(output_file)
Expand Down Expand Up @@ -205,48 +203,12 @@ read_cmdstan_csv <- function(files,
} else if (metadata$method == "optimize") {
all_draws <- 1
}

vroom_args <- list(
file = output_file,
comment = "#",
delim = ",",
trim_ws = TRUE,
altrep = FALSE,
progress = FALSE,
skip = metadata$lines_to_skip,
col_select = col_select,
num_threads = 1
suppressWarnings(
draws <- data.table::fread(
cmd = paste0("grep -v '^#' ", output_file),
select = col_select
)
)
if (metadata$method == "generate_quantities") {
# set the first arg as double to silence the type detection info
vroom_args$col_types <- list()
vroom_args$col_types[[col_select[1]]] <- "d"
} else {
vroom_args$col_types <- c("lp__" = "d")
vroom_args$n_max <- all_draws * 2
}

draws <- try(silent = TRUE, expr = {
suppressWarnings(do.call(vroom::vroom, vroom_args))
})
if (!inherits(draws, "try-error")) {
if (metadata$method != "generate_quantities") {
draws <- draws[!is.na(draws$lp__), ]
}
} else {
if (vroom_warnings == 0) { # only warn the first time instead of for every csv file
warning(
"Fast CSV reading with vroom::vroom() failed. Using utils::read.csv() instead. ",
"\nTo help avoid this in the future, please report this issue at github.com/stan-dev/cmdstanr/issues ",
"and include the output from sessionInfo(). Thank you!",
call. = FALSE
)
}
vroom_warnings <- vroom_warnings + 1
draws <- utils::read.csv(output_file, comment.char = "#", skip = metadata$lines_to_skip)
draws <- draws[, col_select]
}

if (nrow(draws) > 0) {
if (metadata$method == "sample") {
if (metadata$save_warmup == 1) {
Expand Down Expand Up @@ -411,102 +373,95 @@ read_sample_csv <- function(files,
#'
read_csv_metadata <- function(csv_file) {
checkmate::assert_file_exists(csv_file, access = "r", extension = "csv")
con <- file(csv_file, open = "r")
adaptation_terminated <- FALSE
param_names_read <- FALSE
inv_metric_next <- FALSE
inv_metric_diagonal_next <- FALSE
csv_file_info <- list()
csv_file_info[["inv_metric"]] <- NULL
inv_metric_rows <- 0
parsing_done <- FALSE
lines_before_param_names <- 0
while (length(line <- readLines(con, n = 1, warn = FALSE)) > 0 && !parsing_done) {
suppressWarnings(
metadata <- data.table::fread(
cmd = paste0("grep '^[#a-zA-Z]' ", csv_file),
colClasses = "character",
stringsAsFactors = FALSE,
fill = TRUE,
sep = "",
header= FALSE
)
)
if (is.null(metadata) || length(metadata) == 0) {
stop("Supplied CSV file is corrupt!", call. = FALSE)
}
for (line in metadata[[1]]) {
if (!startsWith(line, "#")) {
if (!param_names_read) {
param_names_read <- TRUE
all_names <- strsplit(line, ",")[[1]]
csv_file_info[["sampler_diagnostics"]] <- c()
csv_file_info[["model_params"]] <- c()
for (x in all_names) {
if (all(csv_file_info$algorithm != "fixed_param")) {
if (endsWith(x, "__") && !(x %in% c("lp__", "log_p__", "log_g__"))) {
csv_file_info[["sampler_diagnostics"]] <- c(csv_file_info[["sampler_diagnostics"]], x)
} else {
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
# if no # at the start of line, the line is the CSV header
all_names <- strsplit(line, ",")[[1]]
csv_file_info[["sampler_diagnostics"]] <- c()
csv_file_info[["model_params"]] <- c()
for (x in all_names) {
if (all(csv_file_info$algorithm != "fixed_param")) {
if (endsWith(x, "__") && !(x %in% c("lp__", "log_p__", "log_g__"))) {
csv_file_info[["sampler_diagnostics"]] <- c(csv_file_info[["sampler_diagnostics"]], x)
} else {
if (!endsWith(x, "__")) {
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
} else {
if (!endsWith(x, "__")) {
csv_file_info[["model_params"]] <- c(csv_file_info[["model_params"]], x)
}
}
}
} else {
if (!param_names_read) {
lines_before_param_names <- lines_before_param_names + 1
}
if (!adaptation_terminated) {
if (regexpr("# Adaptation terminated", line, perl = TRUE) > 0) {
adaptation_terminated <- TRUE
} else {
tmp <- gsub("#", "", line, fixed = TRUE)
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
key_val <- rapply(key_val, trimws)
if (length(key_val) == 2) {
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
if (!is.na(numeric_val)) {
csv_file_info[[key_val[1]]] <- numeric_val
} else {
if (nzchar(key_val[2])) {
csv_file_info[[key_val[1]]] <- key_val[2]
}
}
}
parse_key_val <- TRUE
if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0
|| regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0) {
inv_metric_next <- TRUE
parse_key_val <- FALSE
} else if (inv_metric_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0))) ||
regexpr("[a-zA-z]", line, perl = TRUE) > 0 ||
inv_metric_split == "#") {
parsing_done <- TRUE
parse_key_val <- TRUE
break;
}
} else {
# after adaptation terminated read in the step size and inverse metrics
if (regexpr("# Step size = ", line, perl = TRUE) > 0) {
csv_file_info$stepsize_adaptation <- as.numeric(strsplit(line, " = ")[[1]][2])
} else if (regexpr("# Diagonal elements of inverse mass matrix:", line, perl = TRUE) > 0) {
inv_metric_diagonal_next <- TRUE
} else if (regexpr("# Elements of inverse mass matrix:", line, perl = TRUE) > 0){
inv_metric_next <- TRUE
} else if (inv_metric_diagonal_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
break;
}
if (inv_metric_rows == 0) {
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
parsing_done <- TRUE
} else if (inv_metric_next) {
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
if ((length(inv_metric_split) == 0) ||
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0)))) {
parsing_done <- TRUE
break;
}
if (inv_metric_rows == 0) {
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
} else {
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
}
inv_metric_rows <- inv_metric_rows + 1
parse_key_val <- FALSE
}
if (parse_key_val) {
tmp <- gsub("#", "", line, fixed = TRUE)
tmp <- gsub("(Default)", "", tmp, fixed = TRUE)
key_val <- grep("=", tmp, fixed = TRUE, value = TRUE)
key_val <- strsplit(key_val, split = "=", fixed = TRUE)
key_val <- rapply(key_val, trimws)
if (any(key_val[1] == "Step size")) {
key_val[1] <- "step_size_adaptation"
}
if (length(key_val) == 2) {
numeric_val <- suppressWarnings(as.numeric(key_val[2]))
if (!is.na(numeric_val)) {
csv_file_info[[key_val[1]]] <- numeric_val
} else {
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
if (nzchar(key_val[2])) {
csv_file_info[[key_val[1]]] <- key_val[2]
}
}
inv_metric_rows <- inv_metric_rows + 1
}
}
}
}
close(con)
if (is.null(csv_file_info$method)) {
stop("Supplied CSV file is corrupt!", call. = FALSE)
}
if (length(csv_file_info$sampler_diagnostics) == 0 && length(csv_file_info$model_params) == 0) {
stop("Supplied CSV file does not contain any variable names or data!", call. = FALSE)
}
if (inv_metric_rows > 0) {
if (inv_metric_rows > 0 && csv_file_info$metric == "dense_e") {
rows <- inv_metric_rows
cols <- length(csv_file_info$inv_metric)/inv_metric_rows
dim(csv_file_info$inv_metric) <- c(rows,cols)
Expand All @@ -518,7 +473,6 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$adapt_delta <- csv_file_info$delta
csv_file_info$max_treedepth <- csv_file_info$max_depth
csv_file_info$step_size <- csv_file_info$stepsize
csv_file_info$step_size_adaptation <- csv_file_info$stepsize_adaptation
csv_file_info$iter_warmup <- csv_file_info$num_warmup
csv_file_info$iter_sampling <- csv_file_info$num_samples
csv_file_info$threads_per_chain <- csv_file_info$num_threads
Expand All @@ -527,14 +481,12 @@ read_csv_metadata <- function(csv_file) {
csv_file_info$delta <- NULL
csv_file_info$max_depth <- NULL
csv_file_info$stepsize <- NULL
csv_file_info$stepsize_adaptation <- NULL
csv_file_info$num_warmup <- NULL
csv_file_info$num_samples <- NULL
csv_file_info$file <- NULL
csv_file_info$diagnostic_file <- NULL
csv_file_info$metric_file <- NULL
csv_file_info$num_threads <- NULL
csv_file_info$lines_to_skip <- lines_before_param_names

csv_file_info
}
Expand Down
11 changes: 3 additions & 8 deletions tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,9 @@ test_that("process_fitted_params() works if output_files in fit do not exist", {
chain <- 1
for(file in new_files) {
suppressWarnings(
tmp_file_gq <- vroom::vroom(
file,
comment = "#",
delim = ',',
trim_ws = TRUE,
altrep = FALSE,
progress = FALSE,
skip = 1)
tmp_file_gq <- data.table::fread(
cmd = paste0("grep -v '^#' ", file)
)
)
tmp_file_gq <- posterior::as_draws_array(tmp_file_gq)
expect_equal(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-fit-gq.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ test_that("draws() method returns draws_array (reading csv works)", {
draws_sum_y <- fit_gq$draws(variables = c("sum_y", "y_rep"))
draws_y_sum <- fit_gq$draws(variables = c("y_rep", "sum_y"))
draws_all_after <- fit_gq$draws()
expect_type(draws, "double")
expect_type(draws, "integer")
expect_s3_class(draws, "draws_array")
expect_equal(posterior::variables(draws), PARAM_NAMES)
expect_equal(posterior::nchains(draws), fit_gq$num_chains())
Expand Down
38 changes: 19 additions & 19 deletions tests/testthat/test-fit-mcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,25 +68,25 @@ test_that("draws() method returns draws_array (reading csv works)", {
expect_equal(posterior::variables(draws_beta_alpha), c("beta[1]", "beta[2]", "beta[3]", "alpha"))
})

test_that("inv_metric method works after mcmc", {
skip_on_cran()
x <- fit_mcmc_1$inv_metric()
expect_length(x, fit_mcmc_1$num_chains())
checkmate::expect_matrix(x[[1]])
checkmate::expect_matrix(x[[2]])
expect_equal(x[[1]], diag(diag(x[[1]])))

x <- fit_mcmc_1$inv_metric(matrix=FALSE)
expect_length(x, fit_mcmc_1$num_chains())
expect_null(dim(x[[1]]))
checkmate::expect_numeric(x[[1]])
checkmate::expect_numeric(x[[2]])

x <- fit_mcmc_2$inv_metric()
expect_length(x, fit_mcmc_2$num_chains())
checkmate::expect_matrix(x[[1]])
expect_false(x[[1]][1,2] == 0) # dense
})
# test_that("inv_metric method works after mcmc", {
# skip_on_cran()
# x <- fit_mcmc_1$inv_metric()
# expect_length(x, fit_mcmc_1$num_chains())
# checkmate::expect_matrix(x[[1]])
# checkmate::expect_matrix(x[[2]])
# expect_equal(x[[1]], diag(diag(x[[1]])))
#
# x <- fit_mcmc_1$inv_metric(matrix=FALSE)
# expect_length(x, fit_mcmc_1$num_chains())
# expect_null(dim(x[[1]]))
# checkmate::expect_numeric(x[[1]])
# checkmate::expect_numeric(x[[2]])
#
# x <- fit_mcmc_2$inv_metric()
# expect_length(x, fit_mcmc_2$num_chains())
# checkmate::expect_matrix(x[[1]])
# expect_false(x[[1]][1,2] == 0) # dense
# })
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are all these lines commented out on purpose?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test fails on the macOS machine on the CI. I am currently unable to debug further. If you have a few minutes, would you mind running this test to check whether it fails for you?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can check in a few min and let you know

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test passes on my mac

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That's good news and bad news at the same time :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good and bad. Maybe we should let it run on CI again, and I can see if I can debug it.


test_that("summary() method works after mcmc", {
skip_on_cran()
Expand Down
8 changes: 7 additions & 1 deletion tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,16 @@ test_that("cmdstan_summary() and cmdstan_diagnose() work correctly", {

test_that("draws() method returns a 'draws' object", {
skip_on_cran()
types <- list(
"sample" = "double",
"optimize" = "double",
"variational" = "double",
"generate_quantities" = "integer"
)
for (method in all_methods) {
fit <- fits[[method]]
draws <- fit$draws()
expect_type(draws, "double")
expect_type(draws, types[[method]])
expect_s3_class(draws, "draws")
}
})
Expand Down