@@ -1063,6 +1063,24 @@ process_init.default <- function(init, ...) {
1063
1063
return (init )
1064
1064
}
1065
1065
1066
+ # ' Remove the leftmost dimension if equal to 1
1067
+ # ' @noRd
1068
+ # ' @param x An array like object
1069
+ .remove_leftmost_dim <- function (x ) {
1070
+ dims <- dim(x )
1071
+ if (length(dims ) == 1 ) {
1072
+ return (drop(x ))
1073
+ } else if (dims [1 ] == 1 ) {
1074
+ new_dims <- dims [- 1 ]
1075
+ # Create a call to subset the array, maintaining all remaining dimensions
1076
+ subset_expr <- as.call(c(as.name(" [" ), list (x ), 1 , rep(TRUE , length(new_dims )), drop = FALSE ))
1077
+ new_x <- eval(subset_expr )
1078
+ return (array (new_x , dim = new_dims ))
1079
+ } else {
1080
+ return (x )
1081
+ }
1082
+ }
1083
+
1066
1084
# ' Write initial values to files if provided as posterior `draws` object
1067
1085
# ' @noRd
1068
1086
# ' @param init A type that inherits the `posterior::draws` class.
@@ -1097,9 +1115,13 @@ process_init.draws <- function(init, num_procs, model_variables = NULL,
1097
1115
draws_rvar <- posterior :: subset_draws(draws_rvar , variable = variable_names )
1098
1116
inits = lapply(1 : num_procs , function (draw_iter ) {
1099
1117
init_i = lapply(variable_names , function (var_name ) {
1100
- x = drop(posterior :: draws_of(drop(
1101
- posterior :: subset_draws(draws_rvar [[var_name ]], draw = draw_iter ))))
1102
- return (x )
1118
+ x = .remove_leftmost_dim(posterior :: draws_of(
1119
+ posterior :: subset_draws(draws_rvar [[var_name ]], draw = draw_iter )))
1120
+ if (model_variables $ parameters [[var_name ]]$ dimensions == 0 ) {
1121
+ return (as.double(x ))
1122
+ } else {
1123
+ return (x )
1124
+ }
1103
1125
})
1104
1126
bad_names = unlist(lapply(variable_names , function (var_name ) {
1105
1127
x = drop(posterior :: draws_of(drop(
@@ -1295,13 +1317,13 @@ process_init_approx <- function(init, num_procs, model_variables = NULL,
1295
1317
# Calculate unique draws based on 'lw' using base R functions
1296
1318
unique_draws = length(unique(draws_df $ lw ))
1297
1319
if (num_procs > unique_draws ) {
1298
- if (inherits(init , " CmdStanPathfinder " )) {
1320
+ if (inherits(init , " CmdStanPathfinder" )) {
1299
1321
algo_name = " Pathfinder "
1300
1322
extra_msg = " Try running Pathfinder with psis_resample=FALSE."
1301
1323
} else if (inherits(init , " CmdStanVB" )) {
1302
1324
algo_name = " CmdStanVB "
1303
1325
extra_msg = " "
1304
- } else if (inherits(init , " CmdStanLaplace " )) {
1326
+ } else if (inherits(init , " CmdStanLaplace" )) {
1305
1327
algo_name = " CmdStanLaplace "
1306
1328
extra_msg = " "
1307
1329
} else {
0 commit comments