Skip to content

Add sig_figs argument #327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,16 @@ CmdStanArgs <- R6::R6Class(
init = NULL,
refresh = NULL,
output_dir = NULL,
validate_csv = TRUE) {
validate_csv = TRUE,
sig_figs = NULL) {

self$model_name <- model_name
self$exe_file <- exe_file
self$proc_ids <- proc_ids
self$data_file <- data_file
self$seed <- seed
self$refresh <- refresh
self$sig_figs <- sig_figs
self$method_args <- method_args
self$method <- self$method_args$method
self$save_latent_dynamics <- save_latent_dynamics
Expand Down Expand Up @@ -147,6 +149,10 @@ CmdStanArgs <- R6::R6Class(
args$output <- c(args$output, paste0("refresh=", self$refresh))
}

if (!is.null(self$sig_figs)) {
args$output <- c(args$output, paste0("sig_figs=", self$sig_figs))
}

args <- do.call(c, append(args, list(use.names = FALSE)))
self$method_args$compose(idx, args)
},
Expand Down Expand Up @@ -458,6 +464,10 @@ validate_cmdstan_args = function(self) {

checkmate::assert_flag(self$save_latent_dynamics)
checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE)
checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE)
if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") {
warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!")
}
if (!is.null(self$refresh)) {
self$refresh <- as.integer(self$refresh)
}
Expand Down
29 changes: 20 additions & 9 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ CmdStanModel$set("public", name = "check_syntax", value = check_syntax_method)
#' term_buffer = NULL,
#' window = NULL,
#' fixed_param = FALSE,
#' sig_figs = NULL,
#' validate_csv = TRUE,
#' show_messages = TRUE
#' )
Expand Down Expand Up @@ -799,6 +800,7 @@ sample_method <- function(data = NULL,
term_buffer = NULL,
window = NULL,
fixed_param = FALSE,
sig_figs = NULL,
validate_csv = TRUE,
show_messages = TRUE,
# deprecated
Expand Down Expand Up @@ -869,7 +871,6 @@ sample_method <- function(data = NULL,
call. = FALSE)
}
}

sample_args <- SampleArgs$new(
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
Expand Down Expand Up @@ -898,7 +899,8 @@ sample_method <- function(data = NULL,
init = init,
refresh = refresh,
output_dir = output_dir,
validate_csv = validate_csv
validate_csv = validate_csv,
sig_figs = sig_figs
)
cmdstan_procs <- CmdStanMCMCProcs$new(
num_procs = chains,
Expand Down Expand Up @@ -943,7 +945,8 @@ CmdStanModel$set("public", name = "sample", value = sample_method)
#' output_dir = NULL,
#' algorithm = NULL,
#' init_alpha = NULL,
#' iter = NULL
#' iter = NULL,
#' sig_figs = NULL
#' )
#' ```
#'
Expand Down Expand Up @@ -975,7 +978,8 @@ optimize_method <- function(data = NULL,
output_dir = NULL,
algorithm = NULL,
init_alpha = NULL,
iter = NULL) {
iter = NULL,
sig_figs = NULL) {
optimize_args <- OptimizeArgs$new(
algorithm = algorithm,
init_alpha = init_alpha,
Expand All @@ -991,7 +995,8 @@ optimize_method <- function(data = NULL,
seed = seed,
init = init,
refresh = refresh,
output_dir = output_dir
output_dir = output_dir,
sig_figs = sig_figs
)

cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0))
Expand Down Expand Up @@ -1038,7 +1043,8 @@ CmdStanModel$set("public", name = "optimize", value = optimize_method)
#' adapt_iter = NULL,
#' tol_rel_obj = NULL,
#' eval_elbo = NULL,
#' output_samples = NULL
#' output_samples = NULL,
#' sig_figs = NULL
#' )
#' ```
#'
Expand Down Expand Up @@ -1088,7 +1094,8 @@ variational_method <- function(data = NULL,
adapt_iter = NULL,
tol_rel_obj = NULL,
eval_elbo = NULL,
output_samples = NULL) {
output_samples = NULL,
sig_figs = NULL) {
variational_args <- VariationalArgs$new(
algorithm = algorithm,
iter = iter,
Expand All @@ -1111,7 +1118,8 @@ variational_method <- function(data = NULL,
seed = seed,
init = init,
refresh = refresh,
output_dir = output_dir
output_dir = output_dir,
sig_figs = sig_figs
)

cmdstan_procs <- CmdStanProcs$new(num_procs = 1, show_stdout_messages = (is.null(refresh) || refresh != 0))
Expand All @@ -1138,6 +1146,7 @@ CmdStanModel$set("public", name = "variational", value = variational_method)
#' data = NULL,
#' seed = NULL,
#' output_dir = NULL,
#' sig_figs = NULL,
#' parallel_chains = getOption("mc.cores", 1),
#' threads_per_chain = NULL
#' )
Expand Down Expand Up @@ -1204,6 +1213,7 @@ generate_quantities_method <- function(fitted_params,
data = NULL,
seed = NULL,
output_dir = NULL,
sig_figs = NULL,
parallel_chains = getOption("mc.cores", 1),
threads_per_chain = NULL) {
checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE)
Expand All @@ -1219,7 +1229,8 @@ generate_quantities_method <- function(fitted_params,
proc_ids = seq_len(chains),
data_file = process_data(data),
seed = seed,
output_dir = output_dir
output_dir = output_dir,
sig_figs = sig_figs
)
cmdstan_procs <- CmdStanGQProcs$new(
num_procs = chains,
Expand Down
8 changes: 6 additions & 2 deletions man-roxygen/model-common-args.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
#' divergences for HMC). To save the temporary files created when
#' `save_latent_dynamics=TRUE` see the
#' [`$save_latent_dynamics_files()`][fit-method-save_latent_dynamics_files] method.
#'
#' * `output_dir`: (string) A path to a directory where CmdStan should write
#' its output CSV files. For interactive use this can typically be left at
#' `NULL` (temporary directory) since CmdStanR makes the CmdStan output (e.g.,
Expand All @@ -51,4 +50,9 @@
#' - If a path, then the files are created in `output_dir` with names
#' corresponding the defaults used by `$save_output_files()` (and similar
#' methods like `$save_latent_dynamics_files()`).
#'
#' * `sig_figs`: (positive integer) The number of significant figures used
#' for the output values. By default, CmdStan represent the output values with
#' 6 significant figures. The upper limit for `sig_figs` is 18. Increasing
#' this value can cause an increased usage of disk space due to larger
#' output CSV files.
#'
1 change: 1 addition & 0 deletions man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/model-method-sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/model-method-variational.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

112 changes: 112 additions & 0 deletions tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,115 @@ test_that("no output with refresh = 0", {
output <- utils::capture.output(tmp <- mod$sample(data = data_list, refresh = 0, chains = 1))
expect_equal(length(output), 3)
})

test_that("sig_figs works with all methods", {
skip_on_cran()
m <- "data {
int<lower=0> N;
int<lower=0> K;
int<lower=0,upper=1> y[N];
matrix[N, K] X;
}
parameters {
real alpha;
vector[K] beta;
}
model {
target += normal_lpdf(alpha | 0, 1);
target += normal_lpdf(beta | 0, 1);
target += bernoulli_logit_glm_lpmf(y | X, alpha, beta);
}
generated quantities {
real p2 = 0.12;
real p5 = 0.12345;
real p9 = 0.123456789;
}"
mod <- cmdstan_model(write_stan_file(m))
utils::capture.output(
sample <- mod$sample(sig_figs = 2, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12, 0.12)
)
utils::capture.output(
sample <- mod$sample(sig_figs = 5, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.12346)
)
utils::capture.output(
sample <- mod$sample(sig_figs = 10, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(sample$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.123456789)
)
utils::capture.output(
variational <- mod$variational(sig_figs = 2, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12, 0.12)
)
utils::capture.output(
variational <- mod$variational(sig_figs = 5, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.12346)
)
utils::capture.output(
variational <- mod$variational(sig_figs = 10, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(variational$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.123456789)
)
utils::capture.output(
gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 2, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12, 0.12)
)
utils::capture.output(
gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 5, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.12346)
)
utils::capture.output(
gq <- mod$generate_quantities(fitted_params = sample, sig_figs = 10, data = testing_data("logistic"))
)
expect_equal(
as.numeric(posterior::subset_draws(gq$draws(), variable = c("p2","p5", "p9"), iteration = 1, chain = 1)),
c(0.12, 0.12345, 0.123456789)
)
utils::capture.output(
opt <- mod$optimize(sig_figs = 2, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(opt$mle()[c("p2","p5", "p9")]),
c(0.12, 0.12, 0.12)
)
utils::capture.output(
opt <- mod$optimize(sig_figs = 5, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(opt$mle()[c("p2","p5", "p9")]),
c(0.12, 0.12345, 0.12346)
)
utils::capture.output(
opt <- mod$optimize(sig_figs = 10, refresh = 0, data = testing_data("logistic"))
)
expect_equal(
as.numeric(opt$mle()[c("p2","p5", "p9")]),
c(0.12, 0.12345, 0.123456789)
)
})



1 change: 1 addition & 0 deletions tests/testthat/test-install.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ test_that("install_cmdstan() works with version and release_url", {
"version and release_url are supplied to install_cmdstan()"
)
expect_true(dir.exists(file.path(dir, "cmdstan-2.23.0")))
set_cmdstan_path(cmdstan_default_path())
})

test_that("toolchain checks on Unix work", {
Expand Down