diff --git a/R/args.R b/R/args.R index 9690247f1..1f6b9a1cb 100644 --- a/R/args.R +++ b/R/args.R @@ -35,7 +35,8 @@ CmdStanArgs <- R6::R6Class( init = NULL, refresh = NULL, output_dir = NULL, - validate_csv = TRUE) { + validate_csv = TRUE, + sig_figs = NULL) { self$model_name <- model_name self$exe_file <- exe_file @@ -43,6 +44,7 @@ CmdStanArgs <- R6::R6Class( self$data_file <- data_file self$seed <- seed self$refresh <- refresh + self$sig_figs <- sig_figs self$method_args <- method_args self$method <- self$method_args$method self$save_latent_dynamics <- save_latent_dynamics @@ -147,6 +149,10 @@ CmdStanArgs <- R6::R6Class( args$output <- c(args$output, paste0("refresh=", self$refresh)) } + if (!is.null(self$sig_figs)) { + args$output <- c(args$output, paste0("sig_figs=", self$sig_figs)) + } + args <- do.call(c, append(args, list(use.names = FALSE))) self$method_args$compose(idx, args) }, @@ -458,6 +464,10 @@ validate_cmdstan_args = function(self) { checkmate::assert_flag(self$save_latent_dynamics) checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE) + checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE) + if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") { + warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!") + } if (!is.null(self$refresh)) { self$refresh <- as.integer(self$refresh) } diff --git a/R/model.R b/R/model.R index ce8f7e2cc..7465ea806 100644 --- a/R/model.R +++ b/R/model.R @@ -657,6 +657,7 @@ CmdStanModel$set("public", name = "check_syntax", value = check_syntax_method) #' term_buffer = NULL, #' window = NULL, #' fixed_param = FALSE, +#' sig_figs = NULL, #' validate_csv = TRUE, #' show_messages = TRUE #' ) @@ -799,6 +800,7 @@ sample_method <- function(data = NULL, term_buffer = NULL, window = NULL, fixed_param = FALSE, + sig_figs = NULL, validate_csv = TRUE, show_messages = TRUE, # deprecated @@ -869,7 +871,6 @@ sample_method <- function(data = NULL, call. = FALSE) } } - sample_args <- SampleArgs$new( iter_warmup = iter_warmup, iter_sampling = iter_sampling, @@ -898,7 +899,8 @@ sample_method <- function(data = NULL, init = init, refresh = refresh, output_dir = output_dir, - validate_csv = validate_csv + validate_csv = validate_csv, + sig_figs = sig_figs ) cmdstan_procs <- CmdStanMCMCProcs$new( num_procs = chains, @@ -943,7 +945,8 @@ CmdStanModel$set("public", name = "sample", value = sample_method) #' output_dir = NULL, #' algorithm = NULL, #' init_alpha = NULL, -#' iter = NULL +#' iter = NULL, +#' sig_figs = NULL #' ) #' ``` #' @@ -975,7 +978,8 @@ optimize_method <- function(data = NULL, output_dir = NULL, algorithm = NULL, init_alpha = NULL, - iter = NULL) { + iter = NULL, + sig_figs = NULL) { optimize_args <- OptimizeArgs$new( algorithm = algorithm, init_alpha = init_alpha, @@ -991,7 +995,8 @@ optimize_method <- function(data = NULL, seed = seed, init = init, refresh = refresh, - output_dir = output_dir + output_dir = output_dir, + sig_figs = sig_figs ) cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0)) @@ -1038,7 +1043,8 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method) #' adapt_iter = NULL, #' tol_rel_obj = NULL, #' eval_elbo = NULL, -#' output_samples = NULL +#' output_samples = NULL, +#' sig_figs = NULL #' ) #' ``` #' @@ -1088,7 +1094,8 @@ variational_method <- function(data = NULL, adapt_iter = NULL, tol_rel_obj = NULL, eval_elbo = NULL, - output_samples = NULL) { + output_samples = NULL, + sig_figs = NULL) { variational_args <- VariationalArgs$new( algorithm = algorithm, iter = iter, @@ -1111,7 +1118,8 @@ variational_method <- function(data = NULL, seed = seed, init = init, refresh = refresh, - output_dir = output_dir + output_dir = output_dir, + sig_figs = sig_figs ) cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0)) @@ -1138,6 +1146,7 @@ CmdStanModel$set("public", name = "variational", value = variational_method) #' data = NULL, #' seed = NULL, #' output_dir = NULL, +#' sig_figs = NULL, #' parallel_chains = getOption("mc.cores", 1), #' threads_per_chain = NULL #' ) @@ -1204,6 +1213,7 @@ generate_quantities_method <- function(fitted_params, data = NULL, seed = NULL, output_dir = NULL, + sig_figs = NULL, parallel_chains = getOption("mc.cores", 1), threads_per_chain = NULL) { checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE) @@ -1219,7 +1229,8 @@ generate_quantities_method <- function(fitted_params, proc_ids = seq_len(chains), data_file = process_data(data), seed = seed, - output_dir = output_dir + output_dir = output_dir, + sig_figs = sig_figs ) cmdstan_procs <- CmdStanGQProcs$new( num_procs = chains, diff --git a/man-roxygen/model-common-args.R b/man-roxygen/model-common-args.R index 525fd194a..7ba913593 100644 --- a/man-roxygen/model-common-args.R +++ b/man-roxygen/model-common-args.R @@ -37,7 +37,6 @@ #' divergences for HMC). To save the temporary files created when #' `save_latent_dynamics=TRUE` see the #' [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] method. -#' #' * `output_dir`: (string) A path to a directory where CmdStan should write #' its output CSV files. For interactive use this can typically be left at #' `NULL` (temporary directory) since CmdStanR makes the CmdStan output (e.g., @@ -51,4 +50,9 @@ #' - If a path, then the files are created in `output_dir` with names #' corresponding the defaults used by `$save_output_files()` (and similar #' methods like `$save_latent_dynamics_files()`). -#' +#' * `sig_figs`: (positive integer) The number of significant figures used +#' for the output values. By default, CmdStan represent the output values with +#' 6 significant figures. The upper limit for `sig_figs` is 18. Increasing +#' this value can cause an increased usage of disk space due to larger +#' output CSV files. +#' diff --git a/man/model-method-generate-quantities.Rd b/man/model-method-generate-quantities.Rd index 022921959..a158e046f 100644 --- a/man/model-method-generate-quantities.Rd +++ b/man/model-method-generate-quantities.Rd @@ -15,6 +15,7 @@ based on previously fitted parameters. data = NULL, seed = NULL, output_dir = NULL, + sig_figs = NULL, parallel_chains = getOption("mc.cores", 1), threads_per_chain = NULL ) diff --git a/man/model-method-optimize.Rd b/man/model-method-optimize.Rd index 001f36220..76b84edef 100644 --- a/man/model-method-optimize.Rd +++ b/man/model-method-optimize.Rd @@ -29,7 +29,8 @@ variables. Thus modes correspond to modes of the model as written. output_dir = NULL, algorithm = NULL, init_alpha = NULL, - iter = NULL + iter = NULL, + sig_figs = NULL ) } } @@ -94,6 +95,11 @@ files are removed when the fitted model object is garbage collected. corresponding the defaults used by \verb{$save_output_files()} (and similar methods like \verb{$save_latent_dynamics_files()}). } +\item \code{sig_figs}: (positive integer) The number of significant figures used +for the output values. By default, CmdStan represent the output values with +6 significant figures. The upper limit for \code{sig_figs} is 18. Increasing +this value can cause an increased usage of disk space due to larger +output CSV files. } } diff --git a/man/model-method-sample.Rd b/man/model-method-sample.Rd index 91aefb131..48f636c9a 100644 --- a/man/model-method-sample.Rd +++ b/man/model-method-sample.Rd @@ -37,6 +37,7 @@ some data. term_buffer = NULL, window = NULL, fixed_param = FALSE, + sig_figs = NULL, validate_csv = TRUE, show_messages = TRUE ) @@ -103,6 +104,11 @@ files are removed when the fitted model object is garbage collected. corresponding the defaults used by \verb{$save_output_files()} (and similar methods like \verb{$save_latent_dynamics_files()}). } +\item \code{sig_figs}: (positive integer) The number of significant figures used +for the output values. By default, CmdStan represent the output values with +6 significant figures. The upper limit for \code{sig_figs} is 18. Increasing +this value can cause an increased usage of disk space due to larger +output CSV files. } } diff --git a/man/model-method-variational.Rd b/man/model-method-variational.Rd index 45d193c46..e3ad66b84 100644 --- a/man/model-method-variational.Rd +++ b/man/model-method-variational.Rd @@ -35,7 +35,8 @@ matrix for the approximation. adapt_iter = NULL, tol_rel_obj = NULL, eval_elbo = NULL, - output_samples = NULL + output_samples = NULL, + sig_figs = NULL ) } } @@ -100,6 +101,11 @@ files are removed when the fitted model object is garbage collected. corresponding the defaults used by \verb{$save_output_files()} (and similar methods like \verb{$save_latent_dynamics_files()}). } +\item \code{sig_figs}: (positive integer) The number of significant figures used +for the output values. By default, CmdStan represent the output values with +6 significant figures. The upper limit for \code{sig_figs} is 18. Increasing +this value can cause an increased usage of disk space due to larger +output CSV files. } } diff --git a/tests/testthat/test-fit-shared.R b/tests/testthat/test-fit-shared.R index 6c75f42ef..ed2e8cb20 100644 --- a/tests/testthat/test-fit-shared.R +++ b/tests/testthat/test-fit-shared.R @@ -267,3 +267,115 @@ test_that("no output with refresh = 0", { output <- utils::capture.output(tmp <- mod$sample(data = data_list, refresh = 0, chains = 1)) expect_equal(length(output), 3) }) + +test_that("sig_figs works with all methods", { + skip_on_cran() + m <- "data { + int N; + int K; + int y[N]; + matrix[N, K] X; + } + parameters { + real alpha; + vector[K] beta; + } + model { + target += normal_lpdf(alpha | 0, 1); + target += normal_lpdf(beta | 0, 1); + target += bernoulli_logit_glm_lpmf(y | X, alpha, beta); + } + generated quantities { + real p2 = 0.12; + real p5 = 0.12345; + real p9 = 0.123456789; + }" + mod <- cmdstan_model(write_stan_file(m)) + utils::capture.output( + sample <- mod$sample(sig_figs = 2, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12, 0.12) + ) + utils::capture.output( + sample <- mod$sample(sig_figs = 5, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.12346) + ) + utils::capture.output( + sample <- mod$sample(sig_figs = 10, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.123456789) + ) + utils::capture.output( + variational <- mod$variational(sig_figs = 2, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12, 0.12) + ) + utils::capture.output( + variational <- mod$variational(sig_figs = 5, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.12346) + ) + utils::capture.output( + variational <- mod$variational(sig_figs = 10, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.123456789) + ) + utils::capture.output( + gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 2, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12, 0.12) + ) + utils::capture.output( + gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 5, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.12346) + ) + utils::capture.output( + gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 10, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)), + c(0.12, 0.12345, 0.123456789) + ) + utils::capture.output( + opt <- mod$optimize(sig_figs = 2, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(opt$mle()[c("p2","p5", "p9")]), + c(0.12, 0.12, 0.12) + ) + utils::capture.output( + opt <- mod$optimize(sig_figs = 5, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(opt$mle()[c("p2","p5", "p9")]), + c(0.12, 0.12345, 0.12346) + ) + utils::capture.output( + opt <- mod$optimize(sig_figs = 10, refresh = 0, data = testing_data("logistic")) + ) + expect_equal( + as.numeric(opt$mle()[c("p2","p5", "p9")]), + c(0.12, 0.12345, 0.123456789) + ) +}) + + + diff --git a/tests/testthat/test-install.R b/tests/testthat/test-install.R index 5c6c946ab..5501a05a5 100644 --- a/tests/testthat/test-install.R +++ b/tests/testthat/test-install.R @@ -121,6 +121,7 @@ test_that("install_cmdstan() works with version and release_url", { "version and release_url are supplied to install_cmdstan()" ) expect_true(dir.exists(file.path(dir, "cmdstan-2.23.0"))) + set_cmdstan_path(cmdstan_default_path()) }) test_that("toolchain checks on Unix work", {