diff --git a/R/args.R b/R/args.R index 7794510e4..a2a31af08 100644 --- a/R/args.R +++ b/R/args.R @@ -826,6 +826,13 @@ process_init_list <- function(init, num_procs, model_variables = NULL) { if (!all(is_parameter_value_supplied)) { missing_parameter_values[[i]] <- parameter_names[!is_parameter_value_supplied] } + for (par_name in parameter_names[is_parameter_value_supplied]) { + # Make sure that initial values for single-element containers don't get + # unboxed when writing to JSON + if (model_variables$parameters[[par_name]]$dimensions == 1 && length(init[[i]][[par_name]]) == 1) { + init[[i]][[par_name]] <- array(init[[i]][[par_name]], dim = 1) + } + } } if (length(missing_parameter_values) > 0) { warning_message <- c( diff --git a/tests/testthat/test-model-init.R b/tests/testthat/test-model-init.R index 221c8dfb3..cbf66c264 100644 --- a/tests/testthat/test-model-init.R +++ b/tests/testthat/test-model-init.R @@ -262,3 +262,25 @@ test_that("print message if not all parameters are initialized", { fixed = TRUE ) }) + +test_that("Initial values for single-element containers treated correctly", { + modcode <- " + data { + real y_mean; + } + parameters { + vector[1] y; + } + model { + y_mean ~ normal(y[1], 1); + } + " + mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE) + expect_no_error( + fit <- mod$sample( + data = list(y_mean = 0), + init = list(list(y = c(0))), + chains = 1 + ) + ) +})