Skip to content

Commit a9c898b

Browse files
authored
Merge pull request #832 from stan-dev/variable-skeleton-dims
Fix variable_skeleton() with containers
2 parents 2b04e4f + 97d1142 commit a9c898b

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

R/utils.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,11 @@ create_skeleton <- function(param_metadata, model_variables,
780780
names(model_variables$generated_quantities))
781781
}
782782
lapply(param_metadata[target_params], function(par_dims) {
783-
array(0, dim = ifelse(length(par_dims) == 0, 1, par_dims))
783+
if ((length(par_dims) == 0)) {
784+
array(0, dim = 1)
785+
} else {
786+
array(0, dim = par_dims)
787+
}
784788
})
785789
}
786790

tests/testthat/test-model-methods.R

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,38 @@ test_that("Model methods can be initialised for models with no data", {
277277
expect_no_error(fit <- mod$sample())
278278
expect_equal(fit$log_prob(5), -12.5)
279279
})
280+
281+
test_that("Variable skeleton returns correct dimensions for matrices", {
282+
skip_if(os_is_wsl())
283+
284+
stan_file <- write_stan_file("
285+
data {
286+
int N;
287+
int K;
288+
}
289+
parameters {
290+
real x_real;
291+
matrix[N,K] x_mat;
292+
vector[K] x_vec;
293+
row_vector[K] x_rowvec;
294+
}
295+
model {
296+
x_real ~ std_normal();
297+
}")
298+
mod <- cmdstan_model(stan_file, compile_model_methods = TRUE,
299+
force_recompile = TRUE)
300+
N <- 4
301+
K <- 3
302+
fit <- mod$sample(data = list(N = N, K = K), chains = 1,
303+
iter_warmup = 1, iter_sampling = 1)
304+
305+
target_skeleton <- list(
306+
x_real = array(0, dim = 1),
307+
x_mat = array(0, dim = c(N, K)),
308+
x_vec = array(0, dim = K),
309+
x_rowvec = array(0, dim = K)
310+
)
311+
312+
expect_equal(fit$variable_skeleton(),
313+
target_skeleton)
314+
})

0 commit comments

Comments
 (0)