Skip to content

Commit 722d196

Browse files
Fixes 975 by only removing leftmost array dimension if equal to 1 (#993)
* Fixes 975 by only removing leftmost array dimension if equal to 1 * Update tests, fix windows error --------- Co-authored-by: Andrew Johnson <[email protected]>
1 parent 356fa04 commit 722d196

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

R/args.R

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,24 @@ process_init.default <- function(init, ...) {
10631063
return(init)
10641064
}
10651065

1066+
#' Remove the leftmost dimension if equal to 1
1067+
#' @noRd
1068+
#' @param x An array like object
1069+
.remove_leftmost_dim <- function(x) {
1070+
dims <- dim(x)
1071+
if (length(dims) == 1) {
1072+
return(drop(x))
1073+
} else if (dims[1] == 1) {
1074+
new_dims <- dims[-1]
1075+
# Create a call to subset the array, maintaining all remaining dimensions
1076+
subset_expr <- as.call(c(as.name("["), list(x), 1, rep(TRUE, length(new_dims)), drop = FALSE))
1077+
new_x <- eval(subset_expr)
1078+
return(array(new_x, dim = new_dims))
1079+
} else {
1080+
return(x)
1081+
}
1082+
}
1083+
10661084
#' Write initial values to files if provided as posterior `draws` object
10671085
#' @noRd
10681086
#' @param init A type that inherits the `posterior::draws` class.
@@ -1097,9 +1115,13 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
10971115
draws_rvar <- posterior::subset_draws(draws_rvar, variable = variable_names)
10981116
inits = lapply(1:num_procs, function(draw_iter) {
10991117
init_i = lapply(variable_names, function(var_name) {
1100-
x = drop(posterior::draws_of(drop(
1101-
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter))))
1102-
return(x)
1118+
x = .remove_leftmost_dim(posterior::draws_of(
1119+
posterior::subset_draws(draws_rvar[[var_name]], draw=draw_iter)))
1120+
if (model_variables$parameters[[var_name]]$dimensions == 0) {
1121+
return(as.double(x))
1122+
} else {
1123+
return(x)
1124+
}
11031125
})
11041126
bad_names = unlist(lapply(variable_names, function(var_name) {
11051127
x = drop(posterior::draws_of(drop(
@@ -1295,13 +1317,13 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
12951317
# Calculate unique draws based on 'lw' using base R functions
12961318
unique_draws = length(unique(draws_df$lw))
12971319
if (num_procs > unique_draws) {
1298-
if (inherits(init, " CmdStanPathfinder ")) {
1320+
if (inherits(init, "CmdStanPathfinder")) {
12991321
algo_name = " Pathfinder "
13001322
extra_msg = " Try running Pathfinder with psis_resample=FALSE."
13011323
} else if (inherits(init, "CmdStanVB")) {
13021324
algo_name = " CmdStanVB "
13031325
extra_msg = ""
1304-
} else if (inherits(init, " CmdStanLaplace ")) {
1326+
} else if (inherits(init, "CmdStanLaplace")) {
13051327
algo_name = " CmdStanLaplace "
13061328
extra_msg = ""
13071329
} else {

tests/testthat/test-model-init.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,3 +310,28 @@ test_that("Initial values for single-element containers treated correctly", {
310310
)
311311
)
312312
})
313+
314+
test_that("Pathfinder inits do not drop dimensions", {
315+
modcode <- "
316+
data {
317+
int N;
318+
vector[N] y;
319+
}
320+
321+
parameters {
322+
matrix[N, 1] mu;
323+
matrix[1, N] mu_2;
324+
vector<lower=0>[N] sigma;
325+
}
326+
327+
model {
328+
target += normal_lupdf(y | mu[:, 1], sigma);
329+
target += normal_lupdf(y | mu_2[1], sigma);
330+
}
331+
"
332+
mod <- cmdstan_model(write_stan_file(modcode), force_recompile = TRUE)
333+
data <- list(N = 100, y = rnorm(100))
334+
pf <- mod$pathfinder(data = data, psis_resample = FALSE)
335+
expect_no_error(fit <- mod$sample(data = data, init = pf, chains = 1,
336+
iter_warmup = 100, iter_sampling = 100))
337+
})

0 commit comments

Comments
 (0)