Skip to content

Commit d3b455f

Browse files
update pathfinder args for psis_resample and lp_calculate (#903)
* update pathfinder args for psis_resample and lp_calculate --------- Co-authored-by: Andrew Johnson <[email protected]>
1 parent 3c7a1a9 commit d3b455f

File tree

4 files changed

+50
-4
lines changed

4 files changed

+50
-4
lines changed

R/args.R

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,9 @@ PathfinderArgs <- R6::R6Class(
566566
num_paths = NULL,
567567
max_lbfgs_iters = NULL,
568568
num_elbo_draws = NULL,
569-
save_single_paths = NULL) {
569+
save_single_paths = NULL,
570+
psis_resample = NULL,
571+
calculate_lp = NULL) {
570572
self$init_alpha <- init_alpha
571573
self$tol_obj <- tol_obj
572574
self$tol_rel_obj <- tol_rel_obj
@@ -580,6 +582,8 @@ PathfinderArgs <- R6::R6Class(
580582
self$max_lbfgs_iters <- max_lbfgs_iters
581583
self$num_elbo_draws <- num_elbo_draws
582584
self$save_single_paths <- save_single_paths
585+
self$psis_resample <- psis_resample
586+
self$calculate_lp <- calculate_lp
583587
invisible(self)
584588
},
585589

@@ -608,7 +612,9 @@ PathfinderArgs <- R6::R6Class(
608612
.make_arg("num_paths"),
609613
.make_arg("max_lbfgs_iters"),
610614
.make_arg("num_elbo_draws"),
611-
.make_arg("save_single_paths")
615+
.make_arg("save_single_paths"),
616+
.make_arg("psis_resample"),
617+
.make_arg("calculate_lp")
612618
)
613619
new_args <- do.call(c, new_args)
614620
c(args, new_args)
@@ -966,6 +972,16 @@ validate_pathfinder_args <- function(self) {
966972
if (!is.null(self$save_single_paths)) {
967973
self$save_single_paths <- 0
968974
}
975+
if (!is.null(self$psis_resample) && is.logical(self$psis_resample)) {
976+
self$psis_resample = as.integer(self$psis_resample)
977+
}
978+
checkmate::assert_integerish(self$psis_resample, null.ok = TRUE,
979+
lower = 0, upper = 1, len = 1)
980+
if (!is.null(self$calculate_lp) && is.logical(self$calculate_lp)) {
981+
self$calculate_lp = as.integer(self$calculate_lp)
982+
}
983+
checkmate::assert_integerish(self$calculate_lp, null.ok = TRUE,
984+
lower = 0, upper = 1, len = 1)
969985

970986

971987
# check args only available for lbfgs and bfgs

R/model.R

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1887,6 +1887,17 @@ CmdStanModel$set("public", name = "variational", value = variational)
18871887
#' calculating the ELBO of the approximation at each iteration of LBFGS.
18881888
#' @param save_single_paths (logical) Whether to save the results of single
18891889
#' pathfinder runs in multi-pathfinder.
1890+
#' @param psis_resample (logical) Whether to perform pareto smoothed importance sampling.
1891+
#' If `TRUE`, the number of draws returned will be equal to `draws`.
1892+
#' If `FALSE`, the number of draws returned will be equal to `single_path_draws * num_paths`.
1893+
#' @param calculate_lp (logical) Whether to calculate the log probability of the draws.
1894+
#' If `TRUE`, the log probability will be calculated and given in the output.
1895+
#' If `FALSE`, the log probability will only be returned for draws used to determine the
1896+
#' ELBO in the pathfinder steps. All other draws will have a log probability of `NA`.
1897+
#' A value of `FALSE` will also turn off pareto smoothed importance sampling as the
1898+
#' lp calculation is needed for PSIS.
1899+
#' @param save_single_paths (logical) Whether to save the results of single
1900+
#' pathfinder runs in multi-pathfinder.
18901901
#' @return A [`CmdStanPathfinder`] object.
18911902
#'
18921903
#' @template seealso-docs
@@ -1915,6 +1926,8 @@ pathfinder <- function(data = NULL,
19151926
max_lbfgs_iters = NULL,
19161927
num_elbo_draws = NULL,
19171928
save_single_paths = NULL,
1929+
psis_resample = NULL,
1930+
calculate_lp = NULL,
19181931
show_messages = TRUE,
19191932
show_exceptions = TRUE) {
19201933
procs <- CmdStanProcs$new(
@@ -1940,7 +1953,9 @@ pathfinder <- function(data = NULL,
19401953
num_paths = num_paths,
19411954
max_lbfgs_iters = max_lbfgs_iters,
19421955
num_elbo_draws = num_elbo_draws,
1943-
save_single_paths = save_single_paths
1956+
save_single_paths = save_single_paths,
1957+
psis_resample = psis_resample,
1958+
calculate_lp = calculate_lp
19441959
)
19451960
args <- CmdStanArgs$new(
19461961
method_args = pathfinder_args,

man/model-method-pathfinder.Rd

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-model-pathfinder.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ ok_arg_values <- list(
3030
draws = 100,
3131
num_paths = 4,
3232
max_lbfgs_iters = 100,
33-
save_single_paths = FALSE)
33+
save_single_paths = FALSE,
34+
calculate_lp = TRUE,
35+
psis_resample=TRUE)
3436

3537
# using any one of these should cause sample() to error
3638
bad_arg_values <- list(

0 commit comments

Comments
 (0)