Skip to content

Commit 1b77cf4

Browse files
authored
Merge pull request #886 from stan-dev/unconstrain-draws
`$unconstrain_draws()` returns draws format
2 parents 2bec769 + a43a178 commit 1b77cf4

File tree

4 files changed

+49
-17
lines changed

4 files changed

+49
-17
lines changed

R/fit.R

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
542542
#' @param files (character vector) The paths to the CmdStan CSV files. These can
543543
#' be files generated by running CmdStanR or running CmdStan directly.
544544
#' @param draws A `posterior::draws_*` object.
545+
#' @param format (string) The format of the returned draws. Must be a valid
546+
#' format from the \pkg{posterior} package.
545547
#'
546548
#' @examples
547549
#' \dontrun{
@@ -562,7 +564,8 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
562564
#' [unconstrain_variables()], [unconstrain_draws()], [variable_skeleton()],
563565
#' [hessian()]
564566
#'
565-
unconstrain_draws <- function(files = NULL, draws = NULL) {
567+
unconstrain_draws <- function(files = NULL, draws = NULL,
568+
format = getOption("cmdstanr_draws_format", "draws_array")) {
566569
if (!is.null(files) || !is.null(draws)) {
567570
if (!is.null(files) && !is.null(draws)) {
568571
stop("Either a list of CSV files or a draws object can be passed, not both",
@@ -600,13 +603,16 @@ unconstrain_draws <- function(files = NULL, draws = NULL) {
600603
skeleton <- self$variable_skeleton(transformed_parameters = FALSE,
601604
generated_quantities = FALSE)
602605
par_columns <- !(names(draws) %in% c(".chain", ".iteration", ".draw"))
603-
unconstrained <- lapply(split(draws, f = draws$.chain), function(chain) {
604-
lapply(asplit(chain, 1), function(draw) {
605-
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
606-
self$unconstrain_variables(variables = par_list)
607-
})
606+
meta_columns <- !par_columns
607+
unconstrained <- lapply(asplit(draws, 1), function(draw) {
608+
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
609+
self$unconstrain_variables(variables = par_list)
608610
})
609-
unconstrained
611+
612+
unconstrained <- do.call(rbind.data.frame, unconstrained)
613+
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
614+
names(unconstrained) <- repair_variable_names(uncon_names)
615+
maybe_convert_draws_format(cbind.data.frame(unconstrained, draws[,meta_columns]), format)
610616
}
611617
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
612618

@@ -1546,7 +1552,7 @@ loo <- function(variables = "log_lik", r_eff = TRUE, moment_match = FALSE, ...)
15461552
loo = loo_result,
15471553
post_draws = function(x, ...) { x$draws(format = "draws_matrix") },
15481554
log_lik_i = log_lik_i,
1549-
unconstrain_pars = function(x, pars, ...) { do.call(rbind, lapply(x$unconstrain_draws(), function(chain) { do.call(rbind, chain) })) },
1555+
unconstrain_pars = function(x, pars, ...) { x$unconstrain_draws(format = "draws_matrix") },
15501556
log_prob_upars = function(x, upars, ...) { apply(upars, 1, x$log_prob) },
15511557
log_lik_i_upars = log_lik_i_upars,
15521558
...

inst/include/model_methods.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <Rcpp.h>
2+
#include <stan/model/model_base.hpp>
23
#include <stan/model/log_prob_grad.hpp>
34
#include <stan/model/log_prob_propto.hpp>
45
#include <boost/random/additive_combine.hpp>
@@ -115,3 +116,21 @@ std::vector<double> constrain_variables(SEXP ext_model_ptr, SEXP base_rng,
115116
ptr->write_array(*rng.get(), upars, params_i, vars, return_trans_pars, return_gen_quants);
116117
return vars;
117118
}
119+
120+
// [[Rcpp::export]]
121+
std::vector<std::string> unconstrained_param_names(SEXP ext_model_ptr, bool return_trans_pars, bool return_gen_quants) {
122+
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
123+
std::vector<std::string> rtn_names;
124+
ptr->unconstrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
125+
return rtn_names;
126+
}
127+
128+
// [[Rcpp::export]]
129+
std::vector<std::string> constrained_param_names(SEXP ext_model_ptr,
130+
bool return_trans_pars,
131+
bool return_gen_quants) {
132+
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
133+
std::vector<std::string> rtn_names;
134+
ptr->constrained_param_names(rtn_names, return_trans_pars, return_gen_quants);
135+
return rtn_names;
136+
}

man/fit-method-unconstrain_draws.Rd

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model-methods.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -221,20 +221,20 @@ test_that("unconstrain_draws returns correct values", {
221221
mod <- cmdstan_model(write_stan_file(model_code),
222222
compile_model_methods = TRUE,
223223
force_recompile = TRUE)
224-
fit <- mod$sample(data = list(N = 0), chains = 1)
224+
fit <- mod$sample(data = list(N = 0), chains = 2)
225225

226226
x_draws <- fit$draws(format = "draws_df")$x
227227

228228
# Unconstrain all internal draws
229-
unconstrained_internal_draws <- fit$unconstrain_draws()[[1]]
229+
unconstrained_internal_draws <- fit$unconstrain_draws()
230230
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_internal_draws))
231231

232232
# Unconstrain external CmdStan CSV files
233-
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())[[1]]
233+
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
234234
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_csv))
235235

236236
# Unconstrain existing draws object
237-
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())[[1]]
237+
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
238238
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_draws))
239239

240240
# With a lower-bounded constraint, the parameter draws should be the
@@ -253,19 +253,19 @@ test_that("unconstrain_draws returns correct values", {
253253
mod <- cmdstan_model(write_stan_file(model_code),
254254
compile_model_methods = TRUE,
255255
force_recompile = TRUE)
256-
fit <- mod$sample(data = list(N = 0), chains = 1)
256+
fit <- mod$sample(data = list(N = 0), chains = 2)
257257

258258
x_draws <- fit$draws(format = "draws_df")$x
259259

260-
unconstrained_internal_draws <- fit$unconstrain_draws()[[1]]
260+
unconstrained_internal_draws <- fit$unconstrain_draws()
261261
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_internal_draws)))
262262

263263
# Unconstrain external CmdStan CSV files
264-
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())[[1]]
264+
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
265265
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_csv)))
266266

267267
# Unconstrain existing draws object
268-
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())[[1]]
268+
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
269269
expect_equal(as.numeric(x_draws), exp(as.numeric(unconstrained_draws)))
270270
})
271271

0 commit comments

Comments
 (0)