Skip to content

Commit 8a9ae96

Browse files
authored
Merge 5c2dbbd into 17678d5
2 parents 17678d5 + 5c2dbbd commit 8a9ae96

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+2240
-138
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(as_draws,CmdStanGQ)
4+
S3method(as_draws,CmdStanLaplace)
45
S3method(as_draws,CmdStanMCMC)
56
S3method(as_draws,CmdStanMLE)
67
S3method(as_draws,CmdStanVB)

R/args.R

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#'
1515
#' * `SampleArgs`: stores arguments specific to `method=sample`.
1616
#' * `OptimizeArgs`: stores arguments specific to `method=optimize`.
17+
#' * `LaplaceArgs`: stores arguments specific to `method=laplace`.
1718
#' * `VariationalArgs`: stores arguments specific to `method=variational`
1819
#' * `GenerateQuantitiesArgs`: stores arguments specific to `method=generate_quantities`
1920
#' * `DiagnoseArgs`: stores arguments specific to `method=diagnose`
@@ -427,6 +428,49 @@ OptimizeArgs <- R6::R6Class(
427428
)
428429

429430

431+
# LaplaceArgs -------------------------------------------------------------
432+
433+
LaplaceArgs <- R6::R6Class(
434+
"LaplaceArgs",
435+
lock_objects = FALSE,
436+
public = list(
437+
method = "laplace",
438+
initialize = function(mode = NULL,
439+
draws = NULL,
440+
jacobian = TRUE) {
441+
checkmate::assert_r6(mode, classes = "CmdStanMLE")
442+
self$mode_object <- mode # keep the CmdStanMLE for later use (can be returned by CmdStanLaplace$mode())
443+
self$mode <- self$mode_object$output_files() # mode <- file path to pass to CmdStan
444+
self$jacobian <- jacobian
445+
self$draws <- draws
446+
invisible(self)
447+
},
448+
validate = function(num_procs) {
449+
validate_laplace_args(self)
450+
invisible(self)
451+
},
452+
453+
# Compose arguments to CmdStan command for laplace-specific
454+
# non-default arguments. Works the same way as compose for sampler args,
455+
# but `idx` is ignored (no multiple chains for optimize or variational)
456+
compose = function(idx = NULL, args = NULL) {
457+
.make_arg <- function(arg_name) {
458+
compose_arg(self, arg_name, idx = NULL)
459+
}
460+
new_args <- list(
461+
"method=laplace",
462+
.make_arg("mode"),
463+
.make_arg("draws"),
464+
.make_arg("jacobian")
465+
)
466+
new_args <- do.call(c, new_args)
467+
c(args, new_args)
468+
}
469+
)
470+
)
471+
472+
473+
430474
# VariationalArgs ---------------------------------------------------------
431475

432476
VariationalArgs <- R6::R6Class(
@@ -712,6 +756,29 @@ validate_optimize_args <- function(self) {
712756
invisible(TRUE)
713757
}
714758

759+
#' Validate arguments for laplace
760+
#' @noRd
761+
#' @param self A `LaplaceArgs` object.
762+
#' @return `TRUE` invisibly unless an error is thrown.
763+
validate_laplace_args <- function(self) {
764+
checkmate::assert_file_exists(self$mode, extension = "csv")
765+
checkmate::assert_integerish(self$draws, lower = 1, null.ok = TRUE, len = 1)
766+
if (!is.null(self$draws)) {
767+
self$draws <- as.integer(self$draws)
768+
}
769+
checkmate::assert_flag(self$jacobian, null.ok = FALSE)
770+
if (self$mode_object$metadata()$jacobian != self$jacobian) {
771+
stop(
772+
"'jacobian' argument to optimize and laplace must match!\n",
773+
"laplace was called with jacobian=", self$jacobian, "\n",
774+
"optimize was run with jacobian=", as.logical(self$mode_object$metadata()$jacobian),
775+
call. = FALSE
776+
)
777+
}
778+
self$jacobian <- as.integer(self$jacobian)
779+
invisible(TRUE)
780+
}
781+
715782
#' Validate arguments for standalone generated quantities
716783
#' @noRd
717784
#' @param self A `GenerateQuantitiesArgs` object.
@@ -764,7 +831,7 @@ validate_variational_args <- function(self) {
764831
self$eval_elbo <- as.integer(self$eval_elbo)
765832
}
766833
checkmate::assert_integerish(self$output_samples, null.ok = TRUE,
767-
lower = 1, len = 1)
834+
lower = 1, len = 1, .var.name = "draws")
768835
if (!is.null(self$output_samples)) {
769836
self$output_samples <- as.integer(self$output_samples)
770837
}

R/csv.R

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#' and memory for models with many parameters.
2828
#'
2929
#' @return
30-
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], or
30+
#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace] or
3131
#' [CmdStanVB] object. Some methods typically defined for those objects will not
3232
#' work (e.g. `save_data_file()`) but the important methods like `$summary()`,
3333
#' `$draws()`, `$sampler_diagnostics()` and others will work fine.
@@ -67,7 +67,8 @@
6767
#'
6868
#' * `point_estimates`: Point estimates for the model parameters.
6969
#'
70-
#' For [variational inference][model-method-variational] the returned list also
70+
#' For [laplace][model-method-laplace] and
71+
#' [variational inference][model-method-variational] the returned list also
7172
#' includes the following components:
7273
#'
7374
#' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
@@ -307,6 +308,11 @@ read_cmdstan_csv <- function(files,
307308
repaired_variables <- repaired_variables[repaired_variables != "lp__"]
308309
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
309310
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
311+
} else if (metadata$method == "laplace") {
312+
metadata$variables <- gsub("log_p__", "lp__", metadata$variables)
313+
metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables)
314+
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
315+
repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables)
310316
}
311317
model_param_dims <- variable_dims(metadata$variables)
312318
metadata$stan_variable_sizes <- model_param_dims
@@ -385,6 +391,29 @@ read_cmdstan_csv <- function(files,
385391
metadata = metadata,
386392
draws = variational_draws
387393
)
394+
} else if (metadata$method == "laplace") {
395+
if (is.null(format)) {
396+
format <- "draws_matrix"
397+
}
398+
as_draws_format <- as_draws_format_fun(format)
399+
if (length(draws) == 0) {
400+
laplace_draws <- NULL
401+
} else {
402+
laplace_draws <- do.call(as_draws_format, list(draws[[1]]))
403+
}
404+
if (!is.null(laplace_draws)) {
405+
if ("log_p__" %in% posterior::variables(laplace_draws)) {
406+
laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__")
407+
}
408+
if ("log_q__" %in% posterior::variables(laplace_draws)) {
409+
laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__")
410+
}
411+
posterior::variables(laplace_draws) <- repaired_variables
412+
}
413+
list(
414+
metadata = metadata,
415+
draws = laplace_draws
416+
)
388417
} else if (metadata$method == "optimize") {
389418
if (is.null(format)) {
390419
format <- "draws_matrix"
@@ -447,7 +476,8 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
447476
csv_contents$metadata$method,
448477
"sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
449478
"optimize" = CmdStanMLE_CSV$new(csv_contents, files),
450-
"variational" = CmdStanVB_CSV$new(csv_contents, files)
479+
"variational" = CmdStanVB_CSV$new(csv_contents, files),
480+
"laplace" = CmdStanLaplace_CSV$new(csv_contents, files)
451481
)
452482
}
453483

@@ -513,6 +543,22 @@ CmdStanMLE_CSV <- R6::R6Class(
513543
),
514544
private = list(output_files_ = NULL)
515545
)
546+
CmdStanLaplace_CSV <- R6::R6Class(
547+
classname = "CmdStanLaplace_CSV",
548+
inherit = CmdStanLaplace,
549+
public = list(
550+
initialize = function(csv_contents, files) {
551+
private$output_files_ <- files
552+
private$draws_ <- csv_contents$draws
553+
private$metadata_ <- csv_contents$metadata
554+
invisible(self)
555+
},
556+
output_files = function(...) {
557+
private$output_files_
558+
}
559+
),
560+
private = list(output_files_ = NULL)
561+
)
516562
CmdStanVB_CSV <- R6::R6Class(
517563
classname = "CmdStanVB_CSV",
518564
inherit = CmdStanVB,
@@ -554,6 +600,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
554600
}
555601
CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
556602
CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
603+
CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
557604
}
558605

559606

@@ -616,7 +663,7 @@ read_csv_metadata <- function(csv_file) {
616663
all_names <- strsplit(line, ",")[[1]]
617664
if (all(csv_file_info$algorithm != "fixed_param")) {
618665
csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
619-
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))]
666+
csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))]
620667
csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
621668
} else {
622669
csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
@@ -719,7 +766,7 @@ read_csv_metadata <- function(csv_file) {
719766
csv_file_info$step_size <- csv_file_info$stepsize
720767
csv_file_info$iter_warmup <- csv_file_info$num_warmup
721768
csv_file_info$iter_sampling <- csv_file_info$num_samples
722-
if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") {
769+
if (csv_file_info$method %in% c("variational", "optimize", "laplace")) {
723770
csv_file_info$threads <- csv_file_info$num_threads
724771
} else {
725772
csv_file_info$threads_per_chain <- csv_file_info$num_threads

R/example.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
#'
5151
cmdstanr_example <-
5252
function(example = c("logistic", "schools", "schools_ncp"),
53-
method = c("sample", "optimize", "variational", "diagnose"),
53+
method = c("sample", "optimize", "laplace", "variational", "diagnose"),
5454
...,
5555
quiet = TRUE,
5656
force_recompile = getOption("cmdstanr_force_recompile", default = FALSE)) {

0 commit comments

Comments
 (0)