Skip to content

Commit d34b77e

Browse files
committed
Improve efficiency of model methods, tidy code
1 parent 15aa9d9 commit d34b77e

File tree

5 files changed

+63
-96
lines changed

5 files changed

+63
-96
lines changed

R/fit.R

Lines changed: 7 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -518,12 +518,8 @@ unconstrain_variables <- function(variables) {
518518
" not provided!", call. = FALSE)
519519
}
520520

521-
# Remove zero-length parameters from model_variables, otherwise process_init
522-
# warns about missing inputs
523-
model_variables$parameters <- model_variables$parameters[nonzero_length_params]
524-
525-
stan_pars <- process_init(list(variables), num_procs = 1, model_variables)
526-
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, stan_pars)
521+
variables_vector <- unlist(variables, recursive = TRUE, use.names = FALSE)
522+
private$model_methods_env_$unconstrain_variables(private$model_methods_env_$model_ptr_, variables_vector)
527523
}
528524
CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_variables)
529525

@@ -571,11 +567,11 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
571567
call. = FALSE)
572568
}
573569
if (!is.null(files)) {
574-
read_csv <- read_cmdstan_csv(files = files, format = "draws_df")
570+
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
575571
draws <- read_csv$post_warmup_draws
576572
}
577573
if (!is.null(draws)) {
578-
draws <- maybe_convert_draws_format(draws, "draws_df")
574+
draws <- maybe_convert_draws_format(draws, "draws_matrix")
579575
}
580576
} else {
581577
if (is.null(private$draws_)) {
@@ -584,7 +580,7 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
584580
}
585581
private$read_csv_(format = "draws_df")
586582
}
587-
draws <- maybe_convert_draws_format(private$draws_, "draws_df")
583+
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
588584
}
589585

590586
model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
@@ -599,19 +595,10 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
599595
pars <- names(model_variables$parameters[nonzero_length_params])
600596

601597
draws <- posterior::subset_draws(draws, variable = pars)
602-
skeleton <- self$variable_skeleton(transformed_parameters = FALSE,
603-
generated_quantities = FALSE)
604-
par_columns <- !(names(draws) %in% c(".chain", ".iteration", ".draw"))
605-
meta_columns <- !par_columns
606-
unconstrained <- lapply(asplit(draws, 1), function(draw) {
607-
par_list <- utils::relist(as.numeric(draw[par_columns]), skeleton)
608-
self$unconstrain_variables(variables = par_list)
609-
})
610-
611-
unconstrained <- do.call(rbind.data.frame, unconstrained)
598+
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
612599
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
613600
names(unconstrained) <- repair_variable_names(uncon_names)
614-
maybe_convert_draws_format(cbind.data.frame(unconstrained, draws[,meta_columns]), format)
601+
maybe_convert_draws_format(unconstrained, format)
615602
}
616603
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)
617604

R/utils.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -757,13 +757,6 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
757757
readLines(system.file("include", "model_methods.cpp",
758758
package = "cmdstanr", mustWork = TRUE)))
759759

760-
if (hessian) {
761-
code <- c("#include <stan/math/mix.hpp>",
762-
code,
763-
readLines(system.file("include", "hessian.cpp",
764-
package = "cmdstanr", mustWork = TRUE)))
765-
}
766-
767760
code <- paste(code, collapse = "\n")
768761
rcpp_source_stan(code, env, verbose)
769762
invisible(NULL)

inst/include/hessian.cpp

Lines changed: 0 additions & 41 deletions
This file was deleted.

inst/include/model_methods.cpp

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <Rcpp.h>
2+
#include <rcpp_eigen_interop.hpp>
23
#include <stan/model/model_base.hpp>
34
#include <stan/model/log_prob_grad.hpp>
45
#include <stan/model/log_prob_propto.hpp>
@@ -26,10 +27,14 @@ using json_data_t = stan::json::json_data;
2627
return std::make_shared<json_data_t>(data_context);
2728
}
2829

30+
stan::model::model_base&
31+
new_model(stan::io::var_context& data_context, unsigned int seed,
32+
std::ostream* msg_stream);
33+
2934
// [[Rcpp::export]]
3035
Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
31-
Rcpp::XPtr<stan_model> ptr(
32-
new stan_model(*var_context(data_path), seed, &Rcpp::Rcout)
36+
Rcpp::XPtr<stan::model::model_base> ptr(
37+
&new_model(*var_context(data_path), seed, &Rcpp::Rcout)
3338
);
3439
Rcpp::XPtr<boost::ecuyer1988> base_rng(new boost::ecuyer1988(seed));
3540
return Rcpp::List::create(
@@ -39,37 +44,56 @@ Rcpp::List model_ptr(std::string data_path, boost::uint32_t seed) {
3944
}
4045

4146
// [[Rcpp::export]]
42-
double log_prob(SEXP ext_model_ptr, std::vector<double> upars,
43-
bool jac_adjust) {
47+
double log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jac_adjust) {
4448
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
45-
std::vector<int> params_i;
4649
if (jac_adjust) {
47-
return stan::model::log_prob_propto<true>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
50+
return stan::model::log_prob_propto<true>(*ptr.get(), upars, &Rcpp::Rcout);
4851
} else {
49-
return stan::model::log_prob_propto<false>(*ptr.get(), upars, params_i, &Rcpp::Rcout);
52+
return stan::model::log_prob_propto<false>(*ptr.get(), upars, &Rcpp::Rcout);
5053
}
5154
}
5255

5356
// [[Rcpp::export]]
54-
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, std::vector<double> upars,
57+
Rcpp::NumericVector grad_log_prob(SEXP ext_model_ptr, Eigen::VectorXd upars,
5558
bool jac_adjust) {
5659
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
57-
std::vector<double> gradients;
58-
std::vector<int> params_i;
60+
Eigen::VectorXd gradients;
5961

6062
double lp;
6163
if (jac_adjust) {
62-
lp = stan::model::log_prob_grad<true, true>(
63-
*ptr.get(), upars, params_i, gradients);
64+
lp = stan::model::log_prob_grad<true, true>(*ptr.get(), upars, gradients);
6465
} else {
65-
lp = stan::model::log_prob_grad<true, false>(
66-
*ptr.get(), upars, params_i, gradients);
66+
lp = stan::model::log_prob_grad<true, false>(*ptr.get(), upars, gradients);
6767
}
68-
Rcpp::NumericVector grad_rtn = Rcpp::wrap(gradients);
68+
Rcpp::NumericVector grad_rtn(Rcpp::wrap(std::move(gradients)));
6969
grad_rtn.attr("log_prob") = lp;
7070
return grad_rtn;
7171
}
7272

73+
// [[Rcpp::export]]
74+
Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd upars, bool jacobian) {
75+
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
76+
77+
auto hessian_functor = [&](auto&& x) {
78+
if (jacobian) {
79+
return ptr->log_prob<true, true>(x, 0);
80+
} else {
81+
return ptr->log_prob<true, false>(x, 0);
82+
}
83+
};
84+
85+
double log_prob;
86+
Eigen::VectorXd grad;
87+
Eigen::MatrixXd hessian;
88+
89+
stan::math::internal::finite_diff_hessian_auto(hessian_functor, upars, log_prob, grad, hessian);
90+
91+
return Rcpp::List::create(
92+
Rcpp::Named("log_prob") = log_prob,
93+
Rcpp::Named("grad_log_prob") = grad,
94+
Rcpp::Named("hessian") = hessian);
95+
}
96+
7397
// [[Rcpp::export]]
7498
size_t get_num_upars(SEXP ext_model_ptr) {
7599
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
@@ -95,12 +119,23 @@ Rcpp::List get_param_metadata(SEXP ext_model_ptr) {
95119
}
96120

97121
// [[Rcpp::export]]
98-
std::vector<double> unconstrain_variables(SEXP ext_model_ptr, std::string init_path) {
122+
Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variables) {
99123
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
100-
std::vector<int> params_i;
101-
std::vector<double> vars;
102-
ptr->transform_inits(*var_context(init_path), params_i, vars, &Rcpp::Rcout);
103-
return vars;
124+
Eigen::VectorXd unconstrained_variables;
125+
ptr->unconstrain_array(variables, unconstrained_variables, &Rcpp::Rcout);
126+
return unconstrained_variables;
127+
}
128+
129+
// [[Rcpp::export]]
130+
Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
131+
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
132+
Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows());
133+
for (int i = 0; i < variables.rows(); i++) {
134+
Eigen::VectorXd unconstrained_variables;
135+
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
136+
unconstrained_draws.col(i) = unconstrained_variables;
137+
}
138+
return unconstrained_draws.transpose();
104139
}
105140

106141
// [[Rcpp::export]]

tests/testthat/test-model-methods.R

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,10 @@ test_that("Methods error if not compiled", {
3434
)
3535
})
3636

37-
test_that("User warned about higher-order autodiff with hessian", {
38-
skip_if(os_is_wsl())
39-
expect_message(
40-
fit$init_model_methods(hessian = TRUE, verbose = TRUE),
41-
"The hessian method relies on higher-order autodiff which is still experimental. Please report any compilation errors that you encounter",
42-
fixed = TRUE
43-
)
44-
})
4537

4638
test_that("Methods return correct values", {
4739
skip_if(os_is_wsl())
40+
fit$init_model_methods(verbose = TRUE)
4841
lp <- fit$log_prob(unconstrained_variables=c(0.1))
4942
expect_equal(lp, -8.6327599208828509347)
5043

0 commit comments

Comments
 (0)