Skip to content

Commit 9008240

Browse files
authored
Merge pull request #390 from stan-dev/allow_draws_array_or_matrix_for_fittedparams
Accept cmdstanvb, draws_array or draws_matrix as fitted_params
2 parents 1eb251b + cf72fe3 commit 9008240

File tree

5 files changed

+246
-79
lines changed

5 files changed

+246
-79
lines changed

R/data.R

Lines changed: 88 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -122,70 +122,106 @@ any_na_elements <- function(data) {
122122
any(has_na_elements)
123123
}
124124

125+
#' Write posterior draws objects to CSV files
126+
#' @noRd
127+
#' @param draws A `draws_array` from posterior pkg
128+
#' @param sampler_diagnostics Either `NULL` or a `draws_array` of sampler diagnostics
129+
#' @return Paths to CSV files (one per chain).
130+
#'
131+
draws_to_csv <- function(draws, sampler_diagnostics = NULL) {
132+
n <- posterior::niterations(draws)
133+
n_chains <- posterior::nchains(draws)
134+
zeros <- rep(0, n * n_chains) # filler for creating dummy sampler diagnostics and lp__ if necessary
135+
if (is.null(sampler_diagnostics)) {
136+
# create dummy sampler diagnostics because CmdStan requires all columns to be present for generate quantities (GQ)
137+
sampler_diagnostics <- posterior::draws_array(
138+
accept_stat__ = zeros,
139+
stepsize__ = zeros,
140+
treedepth__ = zeros,
141+
n_leapfrog__ = zeros,
142+
divergent__ = zeros,
143+
energy__ = zeros,
144+
.nchains = n_chains
145+
)
146+
}
147+
148+
# the columns must be in order "lp__, sampler_diagnostics, parameters"
149+
draws_variables <- posterior::variables(draws)
150+
if ("lp__" %in% draws_variables) {
151+
lp__ <- NULL
152+
} else { # create a dummy lp__ if it does not exist
153+
lp__ <- posterior::draws_array(lp__ = zeros, .nchains = n_chains)
154+
}
155+
all_variables <- c("lp__", posterior::variables(sampler_diagnostics), draws_variables[!(draws_variables %in% c("lp__", "lp_approx__"))])
156+
draws <- posterior::subset_draws(
157+
posterior::bind_draws(draws, sampler_diagnostics, lp__, along = "variable"),
158+
variable = all_variables
159+
)
160+
161+
chains <- posterior::chain_ids(draws)
162+
paths <- generate_file_names(basename = "fittedParams", ids = chains)
163+
paths <- file.path(tempdir(), paths)
164+
chain <- 1
165+
for (path in paths) {
166+
write(
167+
paste0("# num_samples = ", n, "\n", paste0(unrepair_variable_names(all_variables), collapse = ",")),
168+
file = path,
169+
append = FALSE
170+
)
171+
utils::write.table(
172+
posterior::subset_draws(draws, chain = chain),
173+
sep = ",",
174+
file = path,
175+
col.names = FALSE,
176+
row.names = FALSE,
177+
append = TRUE
178+
)
179+
chain <- chain + 1
180+
}
181+
paths
182+
}
125183

126184
#' Process fitted params for the generate quantities method
127185
#'
128186
#' @noRd
129-
#' @param fitted_params Paths to CSV files compatible with CmdStan or a CmdStanMCMC object.
187+
#' @param fitted_params Paths to CSV files produced by CmdStan sampling,
188+
#' a CmdStanMCMC or CmdStanVB object, a draws_array or draws_matrix.
130189
#' @return Paths to CSV files containing parameter values.
131190
#'
132191
process_fitted_params <- function(fitted_params) {
133192
if (is.character(fitted_params)) {
134193
paths <- absolute_path(fitted_params)
135-
} else if (checkmate::test_r6(fitted_params, classes = ("CmdStanMCMC"))) {
136-
if (all(file.exists(fitted_params$output_files()))) {
194+
} else if (checkmate::test_r6(fitted_params, classes = "CmdStanMCMC") &&
195+
all(file.exists(fitted_params$output_files()))) {
137196
paths <- absolute_path(fitted_params$output_files())
138-
} else {
139-
draws <- tryCatch(posterior::as_draws_array(fitted_params$draws()),
140-
error=function(cond) {
141-
stop("Unable to obtain draws from the fit (CmdStanMCMC) object.", call. = FALSE)
142-
}
143-
)
144-
sampler_diagnostics <- tryCatch(posterior::as_draws_array(fitted_params$sampler_diagnostics()),
145-
error=function(cond) {
146-
stop("Unable to obtain sampler diagnostics from the fit (CmdStanMCMC) object.", call. = FALSE)
147-
}
148-
)
149-
if (!is.null(draws)) {
150-
variables <- posterior::variables(draws)
151-
non_lp_variables <- variables[variables != "lp__"]
152-
draws <- posterior::bind_draws(
153-
posterior::subset_draws(draws, variable = "lp__"),
154-
sampler_diagnostics,
155-
posterior::subset_draws(draws, variable = non_lp_variables),
156-
along = "variable"
157-
)
158-
variables <- posterior::variables(draws)
159-
chains <- posterior::chain_ids(draws)
160-
iterations <- posterior::niterations(draws)
161-
paths <- generate_file_names(basename = "fittedParams", ids = chains)
162-
paths <- file.path(tempdir(), paths)
163-
chain <- 1
164-
for (path in paths) {
165-
chain_draws <- posterior::subset_draws(draws, chain = chain)
166-
write(
167-
paste0("# num_samples = ", iterations),
168-
file = path
169-
)
170-
write(
171-
paste0(unrepair_variable_names(variables), collapse = ","),
172-
file = path,
173-
append = TRUE
174-
)
175-
utils::write.table(
176-
chain_draws,
177-
file = path,
178-
sep = ",",
179-
col.names = FALSE,
180-
row.names = FALSE,
181-
append = TRUE
182-
)
183-
chain <- chain + 1
184-
}
197+
} else if(checkmate::test_r6(fitted_params, classes = c("CmdStanMCMC"))) {
198+
draws <- tryCatch(fitted_params$draws(),
199+
error=function(cond) {
200+
stop("Unable to obtain draws from the fit object.", call. = FALSE)
185201
}
186-
}
202+
)
203+
sampler_diagnostics <- tryCatch(fitted_params$sampler_diagnostics(),
204+
error=function(cond) {
205+
NULL
206+
}
207+
)
208+
paths <- draws_to_csv(draws, sampler_diagnostics)
209+
} else if(checkmate::test_r6(fitted_params, classes = c("CmdStanVB"))) {
210+
draws <- tryCatch(fitted_params$draws(),
211+
error=function(cond) {
212+
stop("Unable to obtain draws from the fit object.", call. = FALSE)
213+
}
214+
)
215+
paths <- draws_to_csv(posterior::as_draws_array(draws))
216+
} else if (any(class(fitted_params) == "draws_array")){
217+
paths <- draws_to_csv(fitted_params)
218+
} else if (any(class(fitted_params) == "draws_matrix")){
219+
paths <- draws_to_csv(posterior::as_draws_array(fitted_params))
187220
} else {
188-
stop("'fitted_params' should be a vector of paths or a CmdStanMCMC object.", call. = FALSE)
221+
stop(
222+
"'fitted_params' must be a list of paths to CSV files, ",
223+
"a CmdStanMCMC/CmdStanVB object, ",
224+
"a posterior::draws_array or a posterior::draws_matrix.", call. = FALSE)
189225
}
190226
paths
191227
}

R/model.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,9 +1430,10 @@ CmdStanModel$set("public", name = "variational", value = variational_method)
14301430
#'
14311431
#' @section Arguments:
14321432
#' * `fitted_params`: (multiple options) The parameter draws to use. One of the following:
1433-
#' - A [CmdStanMCMC] fitted model object.
1434-
#' - A character vector of paths to CmdStan CSV output files containing
1435-
#' parameter draws.
1433+
#' - A [CmdStanMCMC] or [CmdStanVB] fitted model object.
1434+
#' - A [posterior::draws_array] (for MCMC) or [posterior::draws_matrix] (for VB)
1435+
#' object returned by CmdStanR's [`$draws()`][fit-method-draws] method.
1436+
#' - A character vector of paths to CmdStan CSV output files.
14361437
#' * `data`, `seed`, `output_dir`, `parallel_chains`, `threads_per_chain`, `sig_figs`:
14371438
#' Same as for the [`$sample()`][model-method-sample] method.
14381439
#'

man/model-method-generate-quantities.Rd

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-data.R

Lines changed: 119 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,32 @@ test_that("process_fitted_params() works with basic input types", {
2525
})
2626

2727
test_that("process_fitted_params() errors with bad args", {
28+
error_msg <- "'fitted_params' must be a list of paths to CSV files, a CmdStanMCMC/CmdStanVB object, a posterior::draws_array or a posterior::draws_matrix."
2829
expect_error(
2930
process_fitted_params(5),
30-
"'fitted_params' should be a vector of paths or a CmdStanMCMC object."
31+
error_msg
3132
)
3233
expect_error(
3334
process_fitted_params(NULL),
34-
"'fitted_params' should be a vector of paths or a CmdStanMCMC object."
35-
)
36-
expect_error(
37-
process_fitted_params(fit_vb),
38-
"'fitted_params' should be a vector of paths or a CmdStanMCMC object."
35+
error_msg
3936
)
4037
expect_error(
4138
process_fitted_params(fit_optimize),
42-
"'fitted_params' should be a vector of paths or a CmdStanMCMC object."
43-
)
44-
45-
fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123)
46-
temp_file <- tempfile(fileext = ".rds")
47-
saveRDS(fit_tmp, file = temp_file)
48-
rm(fit_tmp)
49-
gc()
50-
fit_tmp_null <- readRDS(temp_file)
51-
expect_error(
52-
process_fitted_params(fit_tmp_null),
53-
"Unable to obtain draws from the fit \\(CmdStanMCMC\\) object."
39+
error_msg
5440
)
5541

5642
fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123)
57-
fit_tmp$draws()
5843
temp_file <- tempfile(fileext = ".rds")
5944
saveRDS(fit_tmp, file = temp_file)
6045
rm(fit_tmp)
6146
gc()
6247
fit_tmp_null <- readRDS(temp_file)
6348
expect_error(
6449
process_fitted_params(fit_tmp_null),
65-
"Unable to obtain sampler diagnostics from the fit \\(CmdStanMCMC\\) object."
50+
"Unable to obtain draws from the fit object."
6651
)
6752
})
6853

69-
7054
test_that("process_fitted_params() works if output_files in fit do not exist", {
7155
fit_ref <- testing_fit("bernoulli", method = "sample", seed = 123)
7256
fit_tmp <- testing_fit("bernoulli", method = "sample", seed = 123)
@@ -108,4 +92,118 @@ test_that("process_fitted_params() works if output_files in fit do not exist", {
10892
}
10993
})
11094

95+
test_that("process_fitted_params() works with CmdStanMCMC", {
96+
fit <- testing_fit("logistic", method = "sample", seed = 123)
97+
fit_params_files <- process_fitted_params(fit)
98+
expect_true(all(file.exists(fit_params_files)))
99+
chain <- 1
100+
for(file in fit_params_files) {
101+
if (os_is_windows()) {
102+
grep_path <- repair_path(Sys.which("grep.exe"))
103+
fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file)
104+
} else {
105+
fread_cmd <- paste0("grep -v '^#' --color=never ", file)
106+
}
107+
suppressWarnings(
108+
fit_params_tmp <- data.table::fread(
109+
cmd = fread_cmd
110+
)
111+
)
112+
fit_params_tmp <- posterior::as_draws_array(fit_params_tmp)
113+
posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp))
114+
expect_equal(
115+
posterior::subset_draws(fit$draws(), variable = "lp__", chain = chain),
116+
posterior::subset_draws(fit_params_tmp, variable = "lp__")
117+
)
118+
expect_equal(
119+
posterior::subset_draws(fit$draws(), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"), chain = chain),
120+
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"),)
121+
)
122+
chain <- chain + 1
123+
}
124+
})
111125

126+
test_that("process_fitted_params() works with draws_array", {
127+
fit <- testing_fit("logistic", method = "sample", seed = 123)
128+
fit_params_files <- process_fitted_params(fit$draws())
129+
expect_true(all(file.exists(fit_params_files)))
130+
chain <- 1
131+
for(file in fit_params_files) {
132+
if (os_is_windows()) {
133+
grep_path <- repair_path(Sys.which("grep.exe"))
134+
fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file)
135+
} else {
136+
fread_cmd <- paste0("grep -v '^#' --color=never ", file)
137+
}
138+
suppressWarnings(
139+
fit_params_tmp <- data.table::fread(
140+
cmd = fread_cmd
141+
)
142+
)
143+
fit_params_tmp <- posterior::as_draws_array(fit_params_tmp)
144+
posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp))
145+
expect_equal(
146+
posterior::subset_draws(fit$draws(), variable = "lp__", chain = chain),
147+
posterior::subset_draws(fit_params_tmp, variable = "lp__")
148+
)
149+
expect_equal(
150+
posterior::subset_draws(fit$draws(), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"), chain = chain),
151+
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"),)
152+
)
153+
chain <- chain + 1
154+
}
155+
})
156+
157+
test_that("process_fitted_params() works with CmdStanVB", {
158+
fit <- testing_fit("logistic", method = "variational", seed = 123)
159+
file <- process_fitted_params(fit)
160+
expect_true(file.exists(file))
161+
if (os_is_windows()) {
162+
grep_path <- repair_path(Sys.which("grep.exe"))
163+
fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file)
164+
} else {
165+
fread_cmd <- paste0("grep -v '^#' --color=never ", file)
166+
}
167+
suppressWarnings(
168+
fit_params_tmp <- data.table::fread(
169+
cmd = fread_cmd
170+
)
171+
)
172+
fit_params_tmp <- posterior::as_draws_array(fit_params_tmp)
173+
posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp))
174+
expect_equal(
175+
posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = "lp__"),
176+
posterior::subset_draws(fit_params_tmp, variable = "lp__")
177+
)
178+
expect_equal(
179+
posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")),
180+
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"))
181+
)
182+
})
183+
184+
test_that("process_fitted_params() works with draws_matrix", {
185+
fit <- testing_fit("logistic", method = "variational", seed = 123)
186+
file <- process_fitted_params(fit$draws())
187+
expect_true(file.exists(file))
188+
if (os_is_windows()) {
189+
grep_path <- repair_path(Sys.which("grep.exe"))
190+
fread_cmd <- paste0(grep_path, " -v '^#' --color=never ", file)
191+
} else {
192+
fread_cmd <- paste0("grep -v '^#' --color=never ", file)
193+
}
194+
suppressWarnings(
195+
fit_params_tmp <- data.table::fread(
196+
cmd = fread_cmd
197+
)
198+
)
199+
fit_params_tmp <- posterior::as_draws_array(fit_params_tmp)
200+
posterior::variables(fit_params_tmp) <- repair_variable_names(posterior::variables(fit_params_tmp))
201+
expect_equal(
202+
posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = "lp__"),
203+
posterior::subset_draws(fit_params_tmp, variable = "lp__")
204+
)
205+
expect_equal(
206+
posterior::subset_draws(posterior::as_draws_array(fit$draws()), variable = c("alpha", "beta[1]", "beta[2]", "beta[3]")),
207+
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"))
208+
)
209+
})

tests/testthat/test-model-generate_quantities.R

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,34 @@ test_that("generate_quantities work for different chains and parallel_chains", {
7070
fixed = TRUE
7171
)
7272
})
73+
74+
test_that("generate_quantities works with draws_array", {
75+
skip_on_cran()
76+
fit_1_chain <- testing_fit("bernoulli", method = "sample", seed = 123, chains = 1)
77+
expect_gq_output(
78+
mod_gq$generate_quantities(data = data_list, fitted_params = fit_1_chain$draws())
79+
)
80+
expect_gq_output(
81+
mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws(), parallel_chains = 2)
82+
)
83+
expect_gq_output(
84+
mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws(), parallel_chains = 4)
85+
)
86+
})
87+
88+
fit <- testing_fit("bernoulli", method = "variational", seed = 123)
89+
mod_gq <- testing_model("bernoulli_ppc")
90+
data_list <- testing_data("bernoulli")
91+
fit_gq <- mod_gq$generate_quantities(data = data_list, fitted_params = fit)
92+
93+
test_that("generate_quantities works with VB and draws_matrix", {
94+
skip_on_cran()
95+
fit <- testing_fit("bernoulli", method = "variational", seed = 123)
96+
fit_gq <- mod_gq$generate_quantities(data = data_list, fitted_params = fit)
97+
expect_gq_output(
98+
mod_gq$generate_quantities(data = data_list, fitted_params = fit)
99+
)
100+
expect_gq_output(
101+
mod_gq$generate_quantities(data = data_list, fitted_params = fit$draws())
102+
)
103+
})

0 commit comments

Comments
 (0)