Skip to content

Commit 973e5ba

Browse files
committed
Add support/testing for exposing functions with tuples
1 parent 77a1443 commit 973e5ba

File tree

3 files changed

+136
-2
lines changed

3 files changed

+136
-2
lines changed

R/utils.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,8 @@ get_cmdstan_flags <- function(flag_name) {
730730

731731
rcpp_source_stan <- function(code, env, verbose = FALSE) {
732732
cxxflags <- get_cmdstan_flags("CXXFLAGS")
733+
cmdstanr_includes <- system.file("include", package = "cmdstanr", mustWork = TRUE)
734+
cmdstanr_includes <- paste0(" -I\"", cmdstanr_includes,"\"")
733735
libs <- c("LDLIBS", "LIBSUNDIALS", "TBB_TARGETS", "LDFLAGS_TBB")
734736
libs <- paste(sapply(libs, get_cmdstan_flags), collapse = " ")
735737
if (.Platform$OS.type == "windows") {
@@ -742,7 +744,7 @@ rcpp_source_stan <- function(code, env, verbose = FALSE) {
742744
c(
743745
USE_CXX14 = 1,
744746
PKG_CPPFLAGS = ifelse(cmdstan_version() <= "2.30.1", "-DCMDSTAN_JSON", ""),
745-
PKG_CXXFLAGS = cxxflags,
747+
PKG_CXXFLAGS = paste0(cxxflags, cmdstanr_includes, collapse = " "),
746748
PKG_LIBS = libs
747749
),
748750
Rcpp::sourceCpp(code = code, env = env, verbose = verbose)
@@ -830,7 +832,8 @@ get_function_name <- function(fun_start, fun_end, model_lines) {
830832
"int",
831833
"double",
832834
"Eigen::Matrix<(.*)>",
833-
"std::vector<(.*)>"
835+
"std::vector<(.*)>",
836+
"std::tuple<(.*)>"
834837
)
835838
pattern <- paste0(
836839
# Only match if the type occurs at start of string
@@ -923,6 +926,7 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
923926

924927
mod_stan_funs <- paste(c(
925928
env$hpp_code[1:(funs[1] - 1)],
929+
"#include <rcpp_tuple_interop.hpp>",
926930
"#include <RcppEigen.h>",
927931
"// [[Rcpp::depends(RcppEigen)]]",
928932
stan_funs),

inst/include/rcpp_tuple_interop.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#include <stan/math/prim/functor/apply.hpp>
2+
#include <Rcpp.h>
3+
4+
namespace Rcpp {
5+
namespace traits {
6+
/**
7+
* The Rcpp::traits::Exporter class is the implementation used when calling
8+
* Rcpp::as<T>() to convert an R object (SEXP) to the requested c++ type.
9+
*/
10+
template <typename... T>
11+
class Exporter<std::tuple<T...>> {
12+
private:
13+
Rcpp::List list_x;
14+
15+
template<std::size_t... I>
16+
auto get_impl(std::index_sequence<I...> i) {
17+
return std::make_tuple(
18+
Rcpp::as<T>(list_x[I].get())...
19+
);
20+
}
21+
22+
public:
23+
Exporter(SEXP x) : list_x(x) { }
24+
std::tuple<T...> get() {
25+
return get_impl(std::index_sequence_for<T...>{});
26+
}
27+
};
28+
}
29+
30+
/**
31+
* The Rcpp::wrap class is used to convert a C++ type to an R object type.
32+
* Rather than implement anything bespoke for tuples we simply return an R list.
33+
*/
34+
template <typename... T>
35+
SEXP wrap(const std::tuple<T...>& x) {
36+
return stan::math::apply([](const auto&... args) {
37+
return Rcpp::List::create(Rcpp::wrap(args)...);
38+
}, x);
39+
}
40+
}

tests/testthat/test-model-expose-functions.R

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,30 @@ functions {
1515
array[] vector rtn_vec_array(array[] vector x) { return x; }
1616
array[] row_vector rtn_rowvec_array(array[] row_vector x) { return x; }
1717
array[] matrix rtn_matrix_array(array[] matrix x) { return x; }
18+
19+
tuple(int, int) rtn_tuple_int(tuple(int, int) x) { return x; }
20+
tuple(real, real) rtn_tuple_real(tuple(real, real) x) { return x; }
21+
tuple(vector, vector) rtn_tuple_vec(tuple(vector, vector) x) { return x; }
22+
tuple(row_vector, row_vector) rtn_tuple_rowvec(tuple(row_vector, row_vector) x) { return x; }
23+
tuple(matrix, matrix) rtn_tuple_matrix(tuple(matrix, matrix) x) { return x; }
24+
25+
tuple(array[] int, array[] int) rtn_tuple_int_array(tuple(array[] int, array[] int) x) { return x; }
26+
tuple(array[] real, array[] real) rtn_tuple_real_array(tuple(array[] real, array[] real) x) { return x; }
27+
tuple(array[] vector, array[] vector) rtn_tuple_vec_array(tuple(array[] vector, array[] vector) x) { return x; }
28+
tuple(array[] row_vector, array[] row_vector) rtn_tuple_rowvec_array(tuple(array[] row_vector, array[] row_vector) x) { return x; }
29+
tuple(array[] matrix, array[] matrix) rtn_tuple_matrix_array(tuple(array[] matrix, array[] matrix) x) { return x; }
30+
31+
tuple(int, tuple(int, int)) rtn_nest_tuple_int(tuple(int, tuple(int, int)) x) { return x; }
32+
tuple(int, tuple(real, real)) rtn_nest_tuple_real(tuple(int, tuple(real, real)) x) { return x; }
33+
tuple(int, tuple(vector, vector)) rtn_nest_tuple_vec(tuple(int, tuple(vector, vector)) x) { return x; }
34+
tuple(int, tuple(row_vector, row_vector)) rtn_nest_tuple_rowvec(tuple(int, tuple(row_vector, row_vector)) x) { return x; }
35+
tuple(int, tuple(matrix, matrix)) rtn_nest_tuple_matrix(tuple(int, tuple(matrix, matrix)) x) { return x; }
36+
37+
tuple(int, tuple(array[] int, array[] int)) rtn_nest_tuple_int_array(tuple(int, tuple(array[] int, array[] int)) x) { return x; }
38+
tuple(int, tuple(array[] real, array[] real)) rtn_nest_tuple_real_array(tuple(int, tuple(array[] real, array[] real)) x) { return x; }
39+
tuple(int, tuple(array[] vector, array[] vector)) rtn_nest_tuple_vec_array(tuple(int, tuple(array[] vector, array[] vector)) x) { return x; }
40+
tuple(int, tuple(array[] row_vector, array[] row_vector)) rtn_nest_tuple_rowvec_array(tuple(int, tuple(array[] row_vector, array[] row_vector)) x) { return x; }
41+
tuple(int, tuple(array[] matrix, array[] matrix)) rtn_nest_tuple_matrix_array(tuple(int, tuple(array[] matrix, array[] matrix)) x) { return x; }
1842
}"
1943
stan_prog <- paste(function_decl,
2044
paste(readLines(testing_stan_file("bernoulli")),
@@ -35,9 +59,13 @@ test_that("Functions can be exposed in model object", {
3559
test_that("Functions handle types correctly", {
3660
skip_if(os_is_wsl())
3761

62+
### Scalar
63+
3864
expect_equal(mod$functions$rtn_int(10), 10)
3965
expect_equal(mod$functions$rtn_real(1.67), 1.67)
4066

67+
### Container
68+
4169
vec <- c(1.2,234,0.3,-0.4)
4270
rowvec <- t(vec)
4371
matrix <- matrix(c(2.11, -6.35, 4.87, -0.9871), nrow = 2, ncol = 2)
@@ -48,13 +76,75 @@ test_that("Functions handle types correctly", {
4876
expect_equal(mod$functions$rtn_int_array(1:5), 1:5)
4977
expect_equal(mod$functions$rtn_real_array(vec), vec)
5078

79+
### Array of Container
80+
5181
vec_array <- list(vec, vec * 2, vec + 0.1)
5282
rowvec_array <- list(rowvec, rowvec * 2, rowvec + 0.1)
5383
matrix_array <- list(matrix, matrix * 2, matrix + 0.1)
5484

5585
expect_equal(mod$functions$rtn_vec_array(vec_array), vec_array)
5686
expect_equal(mod$functions$rtn_rowvec_array(rowvec_array), rowvec_array)
5787
expect_equal(mod$functions$rtn_matrix_array(matrix_array), matrix_array)
88+
89+
### Tuple of Scalar
90+
91+
tuple_int <- list(10, 35)
92+
tuple_dbl <- list(31.87, -19.09)
93+
expect_equal(mod$functions$rtn_tuple_int(tuple_int), tuple_int)
94+
expect_equal(mod$functions$rtn_tuple_real(tuple_dbl), tuple_dbl)
95+
96+
### Tuple of Container
97+
98+
tuple_vec <- list(vec, vec * 12)
99+
tuple_rowvec <- list(rowvec, rowvec * 0.5)
100+
tuple_matrix <- list(matrix, matrix * 0.23)
101+
tuple_int_array <- list(1:10, -3:2)
102+
103+
expect_equal(mod$functions$rtn_tuple_vec(tuple_vec), tuple_vec)
104+
expect_equal(mod$functions$rtn_tuple_rowvec(tuple_rowvec), tuple_rowvec)
105+
expect_equal(mod$functions$rtn_tuple_matrix(tuple_matrix), tuple_matrix)
106+
expect_equal(mod$functions$rtn_tuple_int_array(tuple_int_array), tuple_int_array)
107+
expect_equal(mod$functions$rtn_tuple_real_array(tuple_vec), tuple_vec)
108+
109+
### Tuple of Container Arrays
110+
111+
tuple_vec_array <- list(vec_array, vec_array)
112+
tuple_rowvec_array <- list(rowvec_array, rowvec_array)
113+
tuple_matrix_array <- list(matrix_array, matrix_array)
114+
115+
expect_equal(mod$functions$rtn_tuple_vec_array(tuple_vec_array), tuple_vec_array)
116+
expect_equal(mod$functions$rtn_tuple_rowvec_array(tuple_rowvec_array), tuple_rowvec_array)
117+
expect_equal(mod$functions$rtn_tuple_matrix_array(tuple_matrix_array), tuple_matrix_array)
118+
119+
### Nested Tuple of Scalar
120+
121+
nest_tuple_int <- list(10, tuple_int)
122+
nest_tuple_dbl <- list(31, tuple_dbl)
123+
expect_equal(mod$functions$rtn_nest_tuple_int(nest_tuple_int), nest_tuple_int)
124+
expect_equal(mod$functions$rtn_nest_tuple_real(nest_tuple_dbl), nest_tuple_dbl)
125+
126+
### Nested Tuple of Container
127+
128+
nest_tuple_vec <- list(12, tuple_vec)
129+
nest_tuple_rowvec <- list(2, tuple_rowvec)
130+
nest_tuple_matrix <- list(-23, tuple_matrix)
131+
nest_tuple_int_array <- list(21, tuple_int_array)
132+
133+
expect_equal(mod$functions$rtn_nest_tuple_vec(nest_tuple_vec), nest_tuple_vec)
134+
expect_equal(mod$functions$rtn_nest_tuple_rowvec(nest_tuple_rowvec), nest_tuple_rowvec)
135+
expect_equal(mod$functions$rtn_nest_tuple_matrix(nest_tuple_matrix), nest_tuple_matrix)
136+
expect_equal(mod$functions$rtn_nest_tuple_int_array(nest_tuple_int_array), nest_tuple_int_array)
137+
expect_equal(mod$functions$rtn_nest_tuple_real_array(nest_tuple_vec), nest_tuple_vec)
138+
139+
### Nested Tuple of Container Arrays
140+
141+
nest_tuple_vec_array <- list(-21, tuple_vec_array)
142+
nest_tuple_rowvec_array <- list(1000, tuple_rowvec_array)
143+
nest_tuple_matrix_array <- list(0, tuple_matrix_array)
144+
145+
expect_equal(mod$functions$rtn_nest_tuple_vec_array(nest_tuple_vec_array), nest_tuple_vec_array)
146+
expect_equal(mod$functions$rtn_nest_tuple_rowvec_array(nest_tuple_rowvec_array), nest_tuple_rowvec_array)
147+
expect_equal(mod$functions$rtn_nest_tuple_matrix_array(nest_tuple_matrix_array), nest_tuple_matrix_array)
58148
})
59149

60150
test_that("Functions can be exposed in fit object", {

0 commit comments

Comments
 (0)