Skip to content

Commit cdd62a0

Browse files
authored
Merge pull request #702 from andrjohns/expose-stan-functions
Add optional method for exposing stan functions to R
2 parents f01934c + b8707ef commit cdd62a0

File tree

8 files changed

+239
-15
lines changed

8 files changed

+239
-15
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,6 @@ Suggests:
4848
rmarkdown,
4949
testthat (>= 2.1.0),
5050
Rcpp,
51-
RcppEigen
51+
RcppEigen,
52+
decor
5253
VignetteBuilder: knitr

R/args.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ CmdStanArgs <- R6::R6Class(
2727
stan_file = NULL,
2828
stan_code = NULL,
2929
model_methods_env = NULL,
30+
standalone_env = NULL,
3031
exe_file,
3132
proc_ids,
3233
method_args,
@@ -45,6 +46,7 @@ CmdStanArgs <- R6::R6Class(
4546
self$stan_code <- stan_code
4647
self$exe_file <- exe_file
4748
self$model_methods_env <- model_methods_env
49+
self$standalone_env <- standalone_env
4850
self$proc_ids <- proc_ids
4951
self$data_file <- data_file
5052
self$seed <- seed

R/fit.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ CmdStanFit <- R6::R6Class(
99
classname = "CmdStanFit",
1010
public = list(
1111
runset = NULL,
12+
functions = NULL,
1213
initialize = function(runset) {
1314
checkmate::assert_r6(runset, classes = "CmdStanRun")
1415
self$runset <- runset
1516
private$model_methods_env_ <- runset$model_methods_env()
17+
self$functions <- runset$standalone_env()
1618

1719
if (!is.null(private$model_methods_env_$model_ptr)) {
1820
initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
@@ -278,6 +280,30 @@ init <- function() {
278280
}
279281
CmdStanFit$set("public", name = "init", value = init)
280282

283+
# Fit method: make the model's standalone Stan functions callable from R.
#
# `self$functions` is the environment handed over by the run set
# (`runset$standalone_env()`); it carries a `compiled` flag and, once compiled,
# the names of the exposed functions in `fun_names`.
#
# Arguments:
#   global  - (logical) also make the functions available in the global
#             environment.
#   verbose - (logical) forwarded to the compilation helper.
#
# Returns `NULL` invisibly; called for its side effects.
expose_functions <- function(global = FALSE, verbose = FALSE) {
  require_suggested_package("Rcpp")
  require_suggested_package("RcppEigen")
  require_suggested_package("decor")

  # The standalone environment defaults to NULL and is only populated by the
  # model's compile step. Without this guard the `compiled` check below fails
  # with the unhelpful "argument is of length zero".
  if (is.null(self$functions)) {
    stop("Standalone function code is not available for this fit: ",
         "it was not generated when the model was compiled.",
         call. = FALSE)
  }

  if (isTRUE(self$functions$compiled)) {
    if (!global) {
      message("Functions already compiled, nothing to do!")
    } else {
      message("Functions already compiled, copying to global environment")
      # Create reference to global environment, avoids NOTE about assigning to global
      pos <- 1
      global_env <- as.environment(pos)
      for (fun_name in self$functions$fun_names) {
        assign(fun_name, get(fun_name, envir = self$functions), envir = global_env)
      }
    }
  } else {
    message("Compiling standalone functions...")
    # NOTE(review): this call is meant to reach the helper with signature
    # (env, verbose, global), which shares this method's top-level name --
    # confirm that the helper's binding is the one visible at call time.
    expose_functions(self$functions, verbose, global)
  }
  invisible(NULL)
}
305+
CmdStanFit$set("public", name = "expose_functions", value = expose_functions)
306+
281307
#' Compile additional methods for accessing the model log-probability function
282308
#' and parameter constraining and unconstraining. This requires the `Rcpp` package.
283309
#'

R/model.R

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,8 @@ CmdStanModel <- R6::R6Class(
219219
precompile_cpp_options_ = NULL,
220220
precompile_stanc_options_ = NULL,
221221
precompile_include_paths_ = NULL,
222-
variables_ = NULL
222+
variables_ = NULL,
223+
standalone_env_ = NULL
223224
),
224225
public = list(
225226
initialize = function(stan_file = NULL, exe_file = NULL, compile, ...) {
@@ -387,6 +388,7 @@ CmdStanModel <- R6::R6Class(
387388
#' (`log_prob()`, `grad_log_prob()`, `constrain_pars()`, `unconstrain_pars()`)
388389
#' @param compile_hessian_method (logical) Should the (experimental) `hessian()` method be
389390
#' compiled with the model methods?
391+
#' @param compile_standalone (logical) Should functions in the Stan model be compiled for use in R?
390392
#' @param threads Deprecated and will be removed in a future release. Please
391393
#' turn on threading via `cpp_options = list(stan_threads = TRUE)` instead.
392394
#'
@@ -438,6 +440,7 @@ compile <- function(quiet = TRUE,
438440
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE),
439441
compile_model_methods = FALSE,
440442
compile_hessian_method = FALSE,
443+
compile_standalone = FALSE,
441444
#deprecated
442445
threads = FALSE) {
443446
if (length(self$stan_file()) == 0) {
@@ -557,6 +560,13 @@ compile <- function(quiet = TRUE,
557560
stanc_built_options <- c(stanc_built_options, paste0("--", option_name, "=", "'", stanc_options[[i]], "'"))
558561
}
559562
}
563+
stancflags_standalone <- c("--standalone-functions", stancflags_val, stanc_built_options)
564+
private$standalone_env_ <- new.env()
565+
private$standalone_env_$compiled <- FALSE
566+
private$standalone_env_$hpp_code <- get_standalone_hpp(temp_stan_file, stancflags_standalone)
567+
if (compile_standalone) {
568+
expose_functions(private$standalone_env_, !quiet)
569+
}
560570
stancflags_val <- paste0("STANCFLAGS += ", stancflags_val, paste0(" ", stanc_built_options, collapse = " "))
561571
withr::with_path(
562572
c(
@@ -1117,6 +1127,7 @@ sample <- function(data = NULL,
11171127
stan_file = self$stan_file(),
11181128
stan_code = suppressWarnings(self$code()),
11191129
model_methods_env = private$model_methods_env_,
1130+
standalone_env = private$standalone_env_,
11201131
model_name = self$model_name(),
11211132
exe_file = self$exe_file(),
11221133
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
@@ -1275,6 +1286,7 @@ sample_mpi <- function(data = NULL,
12751286
stan_file = self$stan_file(),
12761287
stan_code = suppressWarnings(self$code()),
12771288
model_methods_env = private$model_methods_env_,
1289+
standalone_env = private$standalone_env_,
12781290
model_name = self$model_name(),
12791291
exe_file = self$exe_file(),
12801292
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
@@ -1387,6 +1399,7 @@ optimize <- function(data = NULL,
13871399
stan_file = self$stan_file(),
13881400
stan_code = suppressWarnings(self$code()),
13891401
model_methods_env = private$model_methods_env_,
1402+
standalone_env = private$standalone_env_,
13901403
model_name = self$model_name(),
13911404
exe_file = self$exe_file(),
13921405
proc_ids = 1,
@@ -1505,6 +1518,7 @@ variational <- function(data = NULL,
15051518
stan_file = self$stan_file(),
15061519
stan_code = suppressWarnings(self$code()),
15071520
model_methods_env = private$model_methods_env_,
1521+
standalone_env = private$standalone_env_,
15081522
model_name = self$model_name(),
15091523
exe_file = self$exe_file(),
15101524
proc_ids = 1,
@@ -1622,6 +1636,7 @@ generate_quantities <- function(fitted_params,
16221636
stan_file = self$stan_file(),
16231637
stan_code = suppressWarnings(self$code()),
16241638
model_methods_env = private$model_methods_env_,
1639+
standalone_env = private$standalone_env_,
16251640
model_name = self$model_name(),
16261641
exe_file = self$exe_file(),
16271642
proc_ids = seq_along(fitted_params_files),
@@ -1686,6 +1701,7 @@ diagnose <- function(data = NULL,
16861701
stan_file = self$stan_file(),
16871702
stan_code = suppressWarnings(self$code()),
16881703
model_methods_env = private$model_methods_env_,
1704+
standalone_env = private$standalone_env_,
16891705
model_name = self$model_name(),
16901706
exe_file = self$exe_file(),
16911707
proc_ids = 1,

R/run.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ CmdStanRun <- R6::R6Class(
3636
exe_file = function() self$args$exe_file,
3737
stan_code = function() self$args$stan_code,
3838
model_methods_env = function() self$args$model_methods_env,
39+
standalone_env = function() self$args$standalone_env,
3940
model_name = function() self$args$model_name,
4041
method = function() self$args$method,
4142
data_file = function() self$args$data_file,

R/utils.R

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,7 @@ as_mcmc.list <- function(x) {
513513
return(mcmc_list)
514514
}
515515

516+
# Model methods & expose_functions helpers ------------------------------------------------------
516517
get_cmdstan_flags <- function(flag_name) {
517518
cmdstan_path <- cmdstanr::cmdstan_path()
518519
flags <- processx::run(
@@ -557,19 +558,7 @@ get_cmdstan_flags <- function(flag_name) {
557558
paste(flags, collapse = " ")
558559
}
559560

560-
expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
561-
code <- c(env$hpp_code_,
562-
readLines(system.file("include", "model_methods.cpp",
563-
package = "cmdstanr", mustWork = TRUE)))
564-
565-
if (hessian) {
566-
code <- c(code,
567-
readLines(system.file("include", "hessian.cpp",
568-
package = "cmdstanr", mustWork = TRUE)))
569-
}
570-
571-
code <- paste(code, collapse = "\n")
572-
561+
rcpp_source_stan <- function(code, env, verbose = FALSE) {
573562
cxxflags <- get_cmdstan_flags("CXXFLAGS")
574563
libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
575564
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = "")
@@ -592,6 +581,22 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
592581
invisible(NULL)
593582
}
594583

584+
# Assemble the C++ source for the model-method bindings and compile it into
# `env`: the model code stored in `env$hpp_code_`, followed by the packaged
# `model_methods.cpp` and, when `hessian = TRUE`, `hessian.cpp`.
# Returns `NULL` invisibly.
expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
  read_include <- function(file_name) {
    readLines(system.file("include", file_name,
                          package = "cmdstanr", mustWork = TRUE))
  }

  cpp_files <- "model_methods.cpp"
  if (hessian) {
    cpp_files <- c(cpp_files, "hessian.cpp")
  }

  sources <- c(env$hpp_code_, unlist(lapply(cpp_files, read_include)))
  rcpp_source_stan(paste(sources, collapse = "\n"), env, verbose)
  invisible(NULL)
}
599+
595600
initialize_model_pointer <- function(env, data, seed = 0) {
596601
ptr_and_rng <- env$model_ptr(data, seed)
597602
env$model_ptr_ <- ptr_and_rng$model_ptr
@@ -609,3 +614,91 @@ create_skeleton <- function(model_variables) {
609614
})
610615
stats::setNames(skeleton, names(model_pars))
611616
}
617+
618+
# Run stanc on `stan_file` with the given `stancflags` and return the lines of
# the generated header.
#
# The `.hpp` file is looked up next to `stan_file` (same base name) and is
# deleted after it has been read. When stanc exits with a non-zero status the
# function returns `NULL` invisibly instead of raising an error.
get_standalone_hpp <- function(stan_file, stancflags) {
  stanc_run <- withr::with_path(
    c(toolchain_PATH_env_var(), tbb_path()),
    wsl_compatible_run(
      command = stanc_cmd(),
      args = c(stan_file, stancflags),
      wd = cmdstan_path(),
      error_on_status = FALSE
    )
  )

  if (stanc_run$status != 0) {
    return(invisible(NULL))
  }

  hpp_file <- file.path(
    dirname(stan_file),
    paste0(strip_ext(basename(stan_file)), ".hpp")
  )
  hpp_lines <- readLines(hpp_file)
  unlink(hpp_file)
  hpp_lines
}
643+
644+
# Construct the plain return type for a standalone function by looking up the
# return type in its functor declaration and replacing the template types
# (e.g. T0__, T10__) with double.
#
# Arguments:
#   fun_body    - character vector with the lines of one exported function;
#                 the first element (the attribute comment line) is dropped
#                 before the signature is parsed.
#   model_lines - character vector with all lines of the generated C++ code.
#
# Returns a single string holding the C++ return type.
get_plain_rtn <- function(fun_body, model_lines) {
  fun_props <- decor::parse_cpp_function(paste(fun_body[-1], collapse = "\n"))

  struct_start <- grep(paste0("struct ", fun_props$name, "_functor"),
                       model_lines, fixed = TRUE)
  if (length(struct_start) == 0) {
    stop("Could not find the functor declaration for function '",
         fun_props$name, "' in the generated C++ code.", call. = FALSE)
  }
  struct_start <- struct_start[1]

  # First `operator()` after the struct declaration line
  struct_op_start <- grep("operator()", model_lines[-seq_len(struct_start)],
                          fixed = TRUE)[1] + struct_start
  if (is.na(struct_op_start)) {
    stop("Could not find the call operator of the functor for function '",
         fun_props$name, "' in the generated C++ code.", call. = FALSE)
  }

  # The return type starts on the line after the first line that closes the
  # template header ("nullptr>").
  # NOTE(review): when the functor has no such line, the return type is
  # assumed to start right after the struct declaration line -- confirm
  # against stanc output for non-templated functions.
  nullptr_lines <- grep("nullptr>", model_lines[struct_start:struct_op_start],
                        fixed = TRUE)
  struct_rtn <- if (length(nullptr_lines) > 0) {
    nullptr_lines[1] + struct_start
  } else {
    struct_start + 1
  }
  # Never start past the operator line (template header and operator on the
  # same line); the nullptr clean-up below handles that layout.
  struct_rtn <- min(struct_rtn, struct_op_start)

  rtn_type <- paste0(model_lines[struct_rtn:struct_op_start], collapse = " ")
  # Drop the call operator and everything after it ...
  rm_operator <- gsub("operator\\(\\).*", "", rtn_type)
  # ... then any template-header remainder that is still in front of the type.
  # (Previously this clean-up was computed but its result was never used.)
  rm_trailing_nullptr <- gsub(".*nullptr>[^,]", "", rm_operator)
  # One or more digits, so T10__ and above are replaced as well
  repl_dbl <- gsub("T[0-9]+__", "double", rm_trailing_nullptr)
  trimws(repl_dbl)
}
660+
661+
# Prepare the C++ code for a standalone function so that it can be exported to R:
# - Replace the auto return type with the plain type
# - Add Rcpp::export attribute
# - Remove the pstream__ argument and pass Rcpp::Rcout by default
# - Replace the boost::ecuyer1988& base_rng__ argument with an integer seed argument
#     that instantiates an RNG
#
# Arguments:
#   fun_body    - character vector with the lines of one exported function,
#                 starting at its `// [[stan::function]]` marker.
#   model_lines - character vector with all lines of the generated C++ code
#                 (used to look up the plain return type).
#
# Returns the rewritten function as a single newline-joined string.
prep_fun_cpp <- function(fun_body, model_lines) {
  plain_rtn <- get_plain_rtn(fun_body, model_lines)
  # Only replace the standalone keyword: a plain substring match would also
  # rewrite identifiers that merely contain "auto" (e.g. a function named
  # `autocorr`), breaking both the declaration and the forwarded call.
  fun_body <- gsub("\\bauto\\b", plain_rtn, fun_body, perl = TRUE)
  fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]", fun_body, fixed = TRUE)
  fun_body <- gsub("std::ostream* pstream__ = nullptr", "", fun_body, fixed = TRUE)
  fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE)
  # NOTE(review): the RNG is allocated with `new` in the emitted C++ and no
  # matching `delete` is emitted here, so each call of an exported RNG function
  # presumably leaks one engine -- consider a function-local object instead.
  fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE)
  fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
  fun_body <- paste(fun_body, collapse = "\n")
  # Removing the pstream__ argument leaves a dangling ", )" in the signature
  gsub(pattern = ",\\s*\\)", replacement = ")", fun_body)
}
677+
678+
# Compile the standalone Stan functions stored in `env$hpp_code` and expose
# them to R.
#
# Arguments:
#   env     - environment holding `hpp_code` (lines of the generated C++ code).
#             On success `fun_names` (names of the exposed functions) and
#             `compiled = TRUE` are stored in it, along with the compiled
#             functions themselves.
#   verbose - (logical) forwarded to `rcpp_source_stan()`.
#   global  - (logical) additionally copy the compiled functions into the
#             global environment.
#
# Returns `NULL` invisibly. Raises an error when `env$hpp_code` contains no
# `// [[stan::function]]` markers (including when it is NULL), instead of
# compiling nothing and still marking the environment as compiled.
expose_functions <- function(env, verbose = FALSE, global = FALSE) {
  fun_markers <- grep("// [[stan::function]]", env$hpp_code, fixed = TRUE)
  if (length(fun_markers) == 0) {
    stop("No standalone Stan functions were found in the generated C++ code, ",
         "nothing to compile.", call. = FALSE)
  }
  # Each function runs from its marker up to the line before the next marker.
  # The last function stops one line short of the end of the file.
  # NOTE(review): presumably the generated file ends with a line that is not
  # part of the last function -- confirm.
  bounds <- c(fun_markers, length(env$hpp_code))

  stan_funs <- vapply(seq_along(fun_markers), function(ind) {
    fun_body <- env$hpp_code[bounds[ind]:(bounds[ind + 1] - 1)]
    prep_fun_cpp(fun_body, env$hpp_code)
  }, character(1))

  # USE.NAMES = FALSE: otherwise every name would be labelled with the full
  # C++ source of its function
  env$fun_names <- vapply(stan_funs, function(fun) {
    decor::parse_cpp_function(fun, is_attribute = TRUE)$name
  }, character(1), USE.NAMES = FALSE)

  mod_stan_funs <- paste(c(
    env$hpp_code[seq_len(fun_markers[1] - 1)],
    "#include <RcppEigen.h>",
    "// [[Rcpp::depends(RcppEigen)]]",
    stan_funs),
    collapse = "\n")

  # Always compile into `env` so that the `compiled` flag set below is
  # truthful for it, then copy to the global environment if requested.
  rcpp_source_stan(mod_stan_funs, env, verbose)
  if (global) {
    # Create reference to global environment, avoids NOTE about assigning to global
    pos <- 1
    global_env <- as.environment(pos)
    for (fun_name in env$fun_names) {
      assign(fun_name, get(fun_name, envir = env, inherits = FALSE),
             envir = global_env)
    }
  }
  env$compiled <- TRUE
  invisible(NULL)
}

man/model-method-compile.Rd

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)