From 856b221e37cc93b838e6c3bc2b0b030c408ee793 Mon Sep 17 00:00:00 2001 From: Rok Cesnovar Date: Sat, 5 Dec 2020 13:54:08 +0100 Subject: [PATCH 1/4] implement draws_to_csv used draws_to_csv in process_fitted_params added tests --- R/data.R | 131 +++++++++------- tests/testthat/test-data.R | 140 +++++++++++++++--- .../testthat/test-model-generate_quantities.R | 31 ++++ 3 files changed, 229 insertions(+), 73 deletions(-) diff --git a/R/data.R b/R/data.R index b760c2a57..7048e9259 100644 --- a/R/data.R +++ b/R/data.R @@ -123,69 +123,96 @@ any_na_elements <- function(data) { } +draws_to_csv <- function(draws, sampler_diagnostics = NULL) { + sampler_diagnostic_names <- c("accept_stat__", "stepsize__", "treedepth__", "n_leapfrog__", "divergent__", "energy__") + + n <- posterior::niterations(draws) + n_chains <- posterior::nchains(draws) + if (is.null(sampler_diagnostics)) { + # create dummy sampler diagnostics due to CmdStan requirement for all columns in GQ + sampler_diagnostics <- rep(0, n * length(sampler_diagnostic_names) * n_chains) + dim(sampler_diagnostics) <- c(n, n_chains, length(sampler_diagnostic_names)) + sampler_diagnostics <- posterior::as_draws_array(sampler_diagnostics) + posterior::variables(sampler_diagnostics) <- sampler_diagnostic_names + } + + # the columns must be in order "lp__, sampler_diagnostics, parameters" + variables <- posterior::variables(draws) + # create a dummy lp__ column if it does not exist + if ("lp__" %in% variables) { + lp__ <- NULL + } else { + lp__ <- rep(0, n * n_chains) + dim(lp__) <- c(n, n_chains, 1) + lp__ <- posterior::as_draws_array(lp__) + posterior::variables(lp__) <- "lp__" + } + variables <- c("lp__", sampler_diagnostic_names, variables[!(variables %in% c("lp__", "lp_approx__"))]) + draws <- posterior::subset_draws( + posterior::bind_draws(draws, sampler_diagnostics, lp__, along = "variable"), + variable = variables + ) + chains <- posterior::chain_ids(draws) + paths <- generate_file_names(basename = "fittedParams", ids = chains) + paths <- file.path(tempdir(), paths) + chain <- 1 + for (path in paths) { + write( + paste0("# num_samples = ", n, "\n", paste0(unrepair_variable_names(variables), collapse = ",")), + file = path, + append = FALSE + ) + utils::write.table( + posterior::subset_draws(draws, chain = chain), + sep = ",", + file = path, + col.names = FALSE, + row.names = FALSE, + append = TRUE + ) + chain <- chain + 1 + } + paths +} + #' Process fitted params for the generate quantities method #' #' @noRd -#' @param fitted_params Paths to CSV files compatible with CmdStan or a CmdStanMCMC object. +#' @param fitted_params Paths to CSV files produced by Cmdstan sampling, +#' a CmdStanMCMC or CmdStanVB object, a draws_array or draws_matrix. #' @return Paths to CSV files containing parameter values. #' process_fitted_params <- function(fitted_params) { if (is.character(fitted_params)) { paths <- absolute_path(fitted_params) - } else if (checkmate::test_r6(fitted_params, classes = ("CmdStanMCMC"))) { - if (all(file.exists(fitted_params$output_files()))) { + } else if (checkmate::test_r6(fitted_params, classes = "CmdStanMCMC") && + all(file.exists(fitted_params$output_files()))) { paths <- absolute_path(fitted_params$output_files()) - } else { - draws <- tryCatch(posterior::as_draws_array(fitted_params$draws()), - error=function(cond) { - stop("Unable to obtain draws from the fit (CmdStanMCMC) object.", call. = FALSE) - } - ) - sampler_diagnostics <- tryCatch(posterior::as_draws_array(fitted_params$sampler_diagnostics()), - error=function(cond) { - stop("Unable to obtain sampler diagnostics from the fit (CmdStanMCMC) object.", call. = FALSE) - } - ) - if (!is.null(draws)) { - variables <- posterior::variables(draws) - non_lp_variables <- variables[variables != "lp__"] - draws <- posterior::bind_draws( - posterior::subset_draws(draws, variable = "lp__"), - sampler_diagnostics, - posterior::subset_draws(draws, variable = non_lp_variables), - along = "variable" - ) - variables <- posterior::variables(draws) - chains <- posterior::chain_ids(draws) - iterations <- posterior::niterations(draws) - paths <- generate_file_names(basename = "fittedParams", ids = chains) - paths <- file.path(tempdir(), paths) - chain <- 1 - for (path in paths) { - chain_draws <- posterior::subset_draws(draws, chain = chain) - write( - paste0("# num_samples = ", iterations), - file = path - ) - write( - paste0(unrepair_variable_names(variables), collapse = ","), - file = path, - append = TRUE - ) - utils::write.table( - chain_draws, - file = path, - sep = ",", - col.names = FALSE, - row.names = FALSE, - append = TRUE - ) - chain <- chain + 1 - } + } else if(checkmate::test_r6(fitted_params, classes = c("CmdStanMCMC"))) { + draws <- tryCatch(fitted_params$draws(), + error=function(cond) { + stop("Unable to obtain draws from the fit object.", call. = FALSE) } - } + ) + sampler_diagnostics <- tryCatch(fitted_params$sampler_diagnostics(), + error=function(cond) { + NULL + } + ) + paths <- draws_to_csv(draws, sampler_diagnostics) + } else if(checkmate::test_r6(fitted_params, classes = c("CmdStanVB"))) { + draws <- tryCatch(fitted_params$draws(), + error=function(cond) { + stop("Unable to obtain draws from the fit object.", call. = FALSE) + } + ) + paths <- draws_to_csv(posterior::as_draws_array(draws)) + } else if (any(class(fitted_params) == "draws_array")){ + paths <- draws_to_csv(fitted_params) + } else if (any(class(fitted_params) == "draws_matrix")){ + paths <- draws_to_csv(posterior::as_draws_array(fitted_params)) } else { - stop("'fitted_params' should be a vector of paths or a CmdStanMCMC object.", call. = FALSE) + stop("'fitted_params' must be a list of paths to CSV files, a CmdStanMCMC/CmdStanVB object, a posterior::draws_array or a posterior::draws_matrix.", call. = FALSE) } paths } diff --git a/tests/testthat/test-data.R b/tests/testthat/test-data.R index de5c9320c..8fa0d3a0b 100644 --- a/tests/testthat/test-data.R +++ b/tests/testthat/test-data.R @@ -25,36 +25,21 @@ test_that("process_fitted_params() works with basic input types", { }) test_that("process_fitted_params() errors with bad args", { + error_msg <- "'fitted_params' must be a list of paths to CSV files, a CmdStanMCMC/CmdStanVB object, a posterior::draws_array or a posterior::draws_matrix." expect_error( process_fitted_params(5), - "'fitted_params' should be a vector of paths or a CmdStanMCMC object." + error_msg ) expect_error( process_fitted_params(NULL), - "'fitted_params' should be a vector of paths or a CmdStanMCMC object." - ) - expect_error( - process_fitted_params(fit_vb), - "'fitted_params' should be a vector of paths or a CmdStanMCMC object." + error_msg ) expect_error( process_fitted_params(fit_optimize), - "'fitted_params' should be a vector of paths or a CmdStanMCMC object." - ) - - fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123) - temp_file <- tempfile(fileext = ".rds") - saveRDS(fit_tmp, file = temp_file) - rm(fit_tmp) - gc() - fit_tmp_null <- readRDS(temp_file) - expect_error( - process_fitted_params(fit_tmp_null), - "Unable to obtain draws from the fit \\(CmdStanMCMC\\) object." + error_msg ) fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123) - fit_tmp$draws() temp_file <- tempfile(fileext = ".rds") saveRDS(fit_tmp, file = temp_file) rm(fit_tmp) @@ -62,11 +47,10 @@ test_that("process_fitted_params() errors with bad args", { fit_tmp_null <- readRDS(temp_file) expect_error( process_fitted_params(fit_tmp_null), - "Unable to obtain sampler diagnostics from the fit \\(CmdStanMCMC\\) object." + "Unable to obtain draws from the fit object." ) }) - test_that("process_fitted_params() works if output_files in fit do not exist", { fit_ref <- testing_fit("bernoulli", method = "sample", seed = 123) fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123) @@ -108,4 +92,118 @@ test_that("process_fitted_params() works if output_files in fit do not exist", { } }) +test_that("process_fitted_params() works with CmdStanMCMC", { + fit <- testing_fit("logistic", method = "sample", seed = 123) + fit_params_files <- process_fitted_params(fit) + expect_true(all(file.exists(fit_params_files))) + chain <- 1 + for(file in fit_params_files) { + if (os_is_windows()) { + grep_path <- repair_path(Sys.which("grep.exe")) + fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file) + } else { + fread_cmd <- paste0("grep -v '^#' --color=never ", file) + } + suppressWarnings( + fit_params_tmp <- data.table::fread( + cmd = fread_cmd + ) + ) + fit_params_tmp <- posterior::as_draws_array(fit_params_tmp) + posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp)) + expect_equal( + posterior::subset_draws(fit$draws(), variable = "lp__", chain = chain), + posterior::subset_draws(fit_params_tmp, variable = "lp__") + ) + expect_equal( + posterior::subset_draws(fit$draws(), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"), chain = chain), + posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"),) + ) + chain <- chain + 1 + } +}) +test_that("process_fitted_params() works with draws_array", { + fit <- testing_fit("logistic", method = "sample", seed = 123) + fit_params_files <- process_fitted_params(fit$draws()) + expect_true(all(file.exists(fit_params_files))) + chain <- 1 + for(file in fit_params_files) { + if (os_is_windows()) { + grep_path <- repair_path(Sys.which("grep.exe")) + fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file) + } else { + fread_cmd <- paste0("grep -v '^#' --color=never ", file) + } + suppressWarnings( + fit_params_tmp <- data.table::fread( + cmd = fread_cmd + ) + ) + fit_params_tmp <- posterior::as_draws_array(fit_params_tmp) + posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp)) + expect_equal( + posterior::subset_draws(fit$draws(), variable = "lp__", chain = chain), + posterior::subset_draws(fit_params_tmp, variable = "lp__") + ) + expect_equal( + posterior::subset_draws(fit$draws(), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"), chain = chain), + posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"),) + ) + chain <- chain + 1 + } +}) + +test_that("process_fitted_params() works with CmdStanVB", { + fit <- testing_fit("logistic", method = "variational", seed = 123) + file <- process_fitted_params(fit) + expect_true(file.exists(file)) + if (os_is_windows()) { + grep_path <- repair_path(Sys.which("grep.exe")) + fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file) + } else { + fread_cmd <- paste0("grep -v '^#' --color=never ", file) + } + suppressWarnings( + fit_params_tmp <- data.table::fread( + cmd = fread_cmd + ) + ) + fit_params_tmp <- posterior::as_draws_array(fit_params_tmp) + posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp)) + expect_equal( + posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = "lp__"), + posterior::subset_draws(fit_params_tmp, variable = "lp__") + ) + expect_equal( + posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")), + posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")) + ) +}) + +test_that("process_fitted_params() works with draws_matrix", { + fit <- testing_fit("logistic", method = "variational", seed = 123) + file <- process_fitted_params(fit$draws()) + expect_true(file.exists(file)) + if (os_is_windows()) { + grep_path <- repair_path(Sys.which("grep.exe")) + fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file) + } else { + fread_cmd <- paste0("grep -v '^#' --color=never ", file) + } + suppressWarnings( + fit_params_tmp <- data.table::fread( + cmd = fread_cmd + ) + ) + fit_params_tmp <- posterior::as_draws_array(fit_params_tmp) + posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp)) + expect_equal( + posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = "lp__"), + posterior::subset_draws(fit_params_tmp, variable = "lp__") + ) + expect_equal( + posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")), + posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")) + ) +}) diff --git a/tests/testthat/test-model-generate_quantities.R b/tests/testthat/test-model-generate_quantities.R index 5e59c4a79..ee2919bdb 100644 --- a/tests/testthat/test-model-generate_quantities.R +++ b/tests/testthat/test-model-generate_quantities.R @@ -70,3 +70,34 @@ test_that("generate_quantities work for different chains and parallel_chains", { fixed = TRUE ) }) + +test_that("generate_quantities works with draws_array", { + skip_on_cran() + fit_1_chain <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 1) + expect_gq_output( + mod_gq$generate_quantities(data = data_list, fitted_params = fit_1_chain$draws()) + ) + expect_gq_output( + mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws(), parallel_chains = 2) + ) + expect_gq_output( + mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws(), parallel_chains = 4) + ) +}) + +fit <- testing_fit("bernoulli", method = "variational", seed = 123) +mod_gq <- testing_model("bernoulli_ppc") +data_list <- testing_data("bernoulli") +fit_gq <- mod_gq$generate_quantities(data = data_list, fitted_params = fit) + +test_that("generate_quantities works with VB and draws_matrix", { + skip_on_cran() + fit <- testing_fit("bernoulli", method = "variational", seed = 123) + fit_gq <- mod_gq$generate_quantities(data = data_list, fitted_params = fit) + expect_gq_output( + mod_gq$generate_quantities(data = data_list, fitted_params = fit) + ) + expect_gq_output( + mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws()) + ) +}) From 0f8f9b8983eb0f3e0a937e3a56b1813b06f47750 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 8 Dec 2020 12:48:18 -0700 Subject: [PATCH 2/4] use posterior::draws_array() constructor for sampler_diagnostics and lp__ --- R/data.R | 42 ++++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/R/data.R b/R/data.R index 7048e9259..30630d5e8 100644 --- a/R/data.R +++ b/R/data.R @@ -122,43 +122,49 @@ any_na_elements <- function(data) { any(has_na_elements) } - +#' Write posterior draws objects to csv files +#' @noRd +#' @param draws A `draws_array` from posterior pkg +#' @param sampler_diagnostics Either `NULL` or a `draws_array` of sampler diagnostics +#' @return Paths to CSV files (one per chain). +#' draws_to_csv <- function(draws, sampler_diagnostics = NULL) { - sampler_diagnostic_names <- c("accept_stat__", "stepsize__", "treedepth__", "n_leapfrog__", "divergent__", "energy__") - n <- posterior::niterations(draws) n_chains <- posterior::nchains(draws) + zeros <- rep(0, n * n_chains) # filler for creating dummy sampler diagnostics and lp__ if necessary if (is.null(sampler_diagnostics)) { # create dummy sampler diagnostics due to CmdStan requirement for all columns in GQ - sampler_diagnostics <- rep(0, n * length(sampler_diagnostic_names) * n_chains) - dim(sampler_diagnostics) <- c(n, n_chains, length(sampler_diagnostic_names)) - sampler_diagnostics <- posterior::as_draws_array(sampler_diagnostics) - posterior::variables(sampler_diagnostics) <- sampler_diagnostic_names + sampler_diagnostics <- posterior::draws_array( + accept_stat__ = zeros, + stepsize__ = zeros, + treedepth__ = zeros, + n_leapfrog__ = zeros, + divergent__ = zeros, + energy__ = zeros, + .nchains = n_chains + ) } # the columns must be in order "lp__, sampler_diagnostics, parameters" - variables <- posterior::variables(draws) - # create a dummy lp__ column if it does not exist - if ("lp__" %in% variables) { + draws_variables <- posterior::variables(draws) + if ("lp__" %in% draws_variables) { lp__ <- NULL - } else { - lp__ <- rep(0, n * n_chains) - dim(lp__) <- c(n, n_chains, 1) - lp__ <- posterior::as_draws_array(lp__) - posterior::variables(lp__) <- "lp__" + } else { # create a dummy lp__ if it does not exist + lp__ <- posterior::draws_array(lp__ = zeros, .nchains = n_chains) } - variables <- c("lp__", sampler_diagnostic_names, variables[!(variables %in% c("lp__", "lp_approx__"))]) + all_variables <- c("lp__", posterior::variables(sampler_diagnostics), draws_variables[!(draws_variables %in% c("lp__", "lp_approx__"))]) draws <- posterior::subset_draws( posterior::bind_draws(draws, sampler_diagnostics, lp__, along = "variable"), - variable = variables + variable = all_variables ) + chains <- posterior::chain_ids(draws) paths <- generate_file_names(basename = "fittedParams", ids = chains) paths <- file.path(tempdir(), paths) chain <- 1 for (path in paths) { write( - paste0("# num_samples = ", n, "\n", paste0(unrepair_variable_names(variables), collapse = ",")), + paste0("# num_samples = ", n, "\n", paste0(unrepair_variable_names(all_variables), collapse = ",")), file = path, append = FALSE ) From 759e5f96331c35e8d05edbc3e3313d96cc9aca76 Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 8 Dec 2020 12:48:50 -0700 Subject: [PATCH 3/4] break up long error message line (doesn't matter, just easier to read) --- R/data.R | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/R/data.R b/R/data.R index 30630d5e8..3e38be3a6 100644 --- a/R/data.R +++ b/R/data.R @@ -191,7 +191,7 @@ draws_to_csv <- function(draws, sampler_diagnostics = NULL) { process_fitted_params <- function(fitted_params) { if (is.character(fitted_params)) { paths <- absolute_path(fitted_params) - } else if (checkmate::test_r6(fitted_params, classes = "CmdStanMCMC") && + } else if (checkmate::test_r6(fitted_params, classes = "CmdStanMCMC") && all(file.exists(fitted_params$output_files()))) { paths <- absolute_path(fitted_params$output_files()) } else if(checkmate::test_r6(fitted_params, classes = c("CmdStanMCMC"))) { @@ -218,7 +218,10 @@ process_fitted_params <- function(fitted_params) { } else if (any(class(fitted_params) == "draws_matrix")){ paths <- draws_to_csv(posterior::as_draws_array(fitted_params)) } else { - stop("'fitted_params' must be a list of paths to CSV files, a CmdStanMCMC/CmdStanVB object, a posterior::draws_array or a posterior::draws_matrix.", call. = FALSE) + stop( + "'fitted_params' must be a list of paths to CSV files, ", + "a CmdStanMCMC/CmdStanVB object, ", + "a posterior::draws_array or a posterior::draws_matrix.", call. = FALSE) } paths } From cf72fe3f9232bcc15a5f373bf2dbfd52fb81400e Mon Sep 17 00:00:00 2001 From: jgabry Date: Tue, 8 Dec 2020 12:55:32 -0700 Subject: [PATCH 4/4] update doc for fitted_params argument --- R/model.R | 7 ++++--- man/model-method-generate-quantities.Rd | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/R/model.R b/R/model.R index ec769ab63..85eb29e54 100644 --- a/R/model.R +++ b/R/model.R @@ -1430,9 +1430,10 @@ CmdStanModel$set("public", name = "variational", value = variational_method) #' #' @section Arguments: #' * `fitted_params`: (multiple options) The parameter draws to use. One of the following: -#' - A [CmdStanMCMC] fitted model object. -#' - A character vector of paths to CmdStan CSV output files containing -#' parameter draws. +#' - A [CmdStanMCMC] or [CmdStanVB] fitted model object. +#' - A [posterior::draws_array] (for MCMC) or [posterior::draws_matrix] (for VB) +#' object returned by CmdStanR's [`$draws()`][fit-method-draws] method. +#' - A character vector of paths to CmdStan CSV output files. #' * `data`, `seed`, `output_dir`, `parallel_chains`, `threads_per_chain`, `sig_figs`: #' Same as for the [`$sample()`][model-method-sample] method. #' diff --git a/man/model-method-generate-quantities.Rd b/man/model-method-generate-quantities.Rd index 8e86539cd..cac349bc8 100644 --- a/man/model-method-generate-quantities.Rd +++ b/man/model-method-generate-quantities.Rd @@ -27,9 +27,10 @@ based on previously fitted parameters. \itemize{ \item \code{fitted_params}: (multiple options) The parameter draws to use. One of the following: \itemize{ -\item A \link{CmdStanMCMC} fitted model object. -\item A character vector of paths to CmdStan CSV output files containing -parameter draws. +\item A \link{CmdStanMCMC} or \link{CmdStanVB} fitted model object. +\item A \link[posterior:draws_array]{posterior::draws_array} (for MCMC) or \link[posterior:draws_matrix]{posterior::draws_matrix} (for VB) +object returned by CmdStanR's \code{\link[=fit-method-draws]{$draws()}} method. +\item A character vector of paths to CmdStan CSV output files. } \item \code{data}, \code{seed}, \code{output_dir}, \code{parallel_chains}, \code{threads_per_chain}, \code{sig_figs}: Same as for the \code{\link[=model-method-sample]{$sample()}} method.