|
27 | 27 | #' and memory for models with many parameters.
|
28 | 28 | #'
|
29 | 29 | #' @return
|
30 |
| -#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], or |
| 30 | +#' `as_cmdstan_fit()` returns a [CmdStanMCMC], [CmdStanMLE], [CmdStanLaplace] or |
31 | 31 | #' [CmdStanVB] object. Some methods typically defined for those objects will not
|
32 | 32 | #' work (e.g. `save_data_file()`) but the important methods like `$summary()`,
|
33 | 33 | #' `$draws()`, `$sampler_diagnostics()` and others will work fine.
|
|
67 | 67 | #'
|
68 | 68 | #' * `point_estimates`: Point estimates for the model parameters.
|
69 | 69 | #'
|
70 |
| -#' For [variational inference][model-method-variational] the returned list also |
| 70 | +#' For [laplace][model-method-laplace] and |
| 71 | +#' [variational inference][model-method-variational] the returned list also |
71 | 72 | #' includes the following components:
|
72 | 73 | #'
|
73 | 74 | #' * `draws`: A [`draws_matrix`][posterior::draws_matrix] (or different format
|
@@ -307,6 +308,11 @@ read_cmdstan_csv <- function(files,
|
307 | 308 | repaired_variables <- repaired_variables[repaired_variables != "lp__"]
|
308 | 309 | repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
|
309 | 310 | repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
|
| 311 | + } else if (metadata$method == "laplace") { |
| 312 | + metadata$variables <- gsub("log_p__", "lp__", metadata$variables) |
| 313 | + metadata$variables <- gsub("log_q__", "lp_approx__", metadata$variables) |
| 314 | + repaired_variables <- gsub("log_p__", "lp__", repaired_variables) |
| 315 | + repaired_variables <- gsub("log_q__", "lp_approx__", repaired_variables) |
310 | 316 | }
|
311 | 317 | model_param_dims <- variable_dims(metadata$variables)
|
312 | 318 | metadata$stan_variable_sizes <- model_param_dims
|
@@ -385,6 +391,29 @@ read_cmdstan_csv <- function(files,
|
385 | 391 | metadata = metadata,
|
386 | 392 | draws = variational_draws
|
387 | 393 | )
|
| 394 | + } else if (metadata$method == "laplace") { |
| 395 | + if (is.null(format)) { |
| 396 | + format <- "draws_matrix" |
| 397 | + } |
| 398 | + as_draws_format <- as_draws_format_fun(format) |
| 399 | + if (length(draws) == 0) { |
| 400 | + laplace_draws <- NULL |
| 401 | + } else { |
| 402 | + laplace_draws <- do.call(as_draws_format, list(draws[[1]])) |
| 403 | + } |
| 404 | + if (!is.null(laplace_draws)) { |
| 405 | + if ("log_p__" %in% posterior::variables(laplace_draws)) { |
| 406 | + laplace_draws <- posterior::rename_variables(laplace_draws, lp__ = "log_p__") |
| 407 | + } |
| 408 | + if ("log_q__" %in% posterior::variables(laplace_draws)) { |
| 409 | + laplace_draws <- posterior::rename_variables(laplace_draws, lp_approx__ = "log_q__") |
| 410 | + } |
| 411 | + posterior::variables(laplace_draws) <- repaired_variables |
| 412 | + } |
| 413 | + list( |
| 414 | + metadata = metadata, |
| 415 | + draws = laplace_draws |
| 416 | + ) |
388 | 417 | } else if (metadata$method == "optimize") {
|
389 | 418 | if (is.null(format)) {
|
390 | 419 | format <- "draws_matrix"
|
@@ -447,7 +476,8 @@ as_cmdstan_fit <- function(files, check_diagnostics = TRUE, format = getOption("
|
447 | 476 | csv_contents$metadata$method,
|
448 | 477 | "sample" = CmdStanMCMC_CSV$new(csv_contents, files, check_diagnostics),
|
449 | 478 | "optimize" = CmdStanMLE_CSV$new(csv_contents, files),
|
450 |
| - "variational" = CmdStanVB_CSV$new(csv_contents, files) |
| 479 | + "variational" = CmdStanVB_CSV$new(csv_contents, files), |
| 480 | + "laplace" = CmdStanLaplace_CSV$new(csv_contents, files) |
451 | 481 | )
|
452 | 482 | }
|
453 | 483 |
|
@@ -513,6 +543,22 @@ CmdStanMLE_CSV <- R6::R6Class(
|
513 | 543 | ),
|
514 | 544 | private = list(output_files_ = NULL)
|
515 | 545 | )
|
| 546 | +CmdStanLaplace_CSV <- R6::R6Class( |
| 547 | + classname = "CmdStanLaplace_CSV", |
| 548 | + inherit = CmdStanLaplace, |
| 549 | + public = list( |
| 550 | + initialize = function(csv_contents, files) { |
| 551 | + private$output_files_ <- files |
| 552 | + private$draws_ <- csv_contents$draws |
| 553 | + private$metadata_ <- csv_contents$metadata |
| 554 | + invisible(self) |
| 555 | + }, |
| 556 | + output_files = function(...) { |
| 557 | + private$output_files_ |
| 558 | + } |
| 559 | + ), |
| 560 | + private = list(output_files_ = NULL) |
| 561 | +) |
516 | 562 | CmdStanVB_CSV <- R6::R6Class(
|
517 | 563 | classname = "CmdStanVB_CSV",
|
518 | 564 | inherit = CmdStanVB,
|
@@ -554,6 +600,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
|
554 | 600 | }
|
555 | 601 | CmdStanMLE_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
|
556 | 602 | CmdStanVB_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV)
|
| 603 | + CmdStanLaplace_CSV$set("public", name = method, value = error_unavailable_CmdStanFit_CSV) |
557 | 604 | }
|
558 | 605 |
|
559 | 606 |
|
@@ -616,7 +663,7 @@ read_csv_metadata <- function(csv_file) {
|
616 | 663 | all_names <- strsplit(line, ",")[[1]]
|
617 | 664 | if (all(csv_file_info$algorithm != "fixed_param")) {
|
618 | 665 | csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")]
|
619 |
| - csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__"))] |
| 666 | + csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))] |
620 | 667 | csv_file_info[["variables"]] <- all_names[!(all_names %in% csv_file_info[["sampler_diagnostics"]])]
|
621 | 668 | } else {
|
622 | 669 | csv_file_info[["variables"]] <- all_names[!endsWith(all_names, "__")]
|
@@ -719,7 +766,7 @@ read_csv_metadata <- function(csv_file) {
|
719 | 766 | csv_file_info$step_size <- csv_file_info$stepsize
|
720 | 767 | csv_file_info$iter_warmup <- csv_file_info$num_warmup
|
721 | 768 | csv_file_info$iter_sampling <- csv_file_info$num_samples
|
722 |
| - if (csv_file_info$method == "variational" || csv_file_info$method == "optimize") { |
| 769 | + if (csv_file_info$method %in% c("variational", "optimize", "laplace")) { |
723 | 770 | csv_file_info$threads <- csv_file_info$num_threads
|
724 | 771 | } else {
|
725 | 772 | csv_file_info$threads_per_chain <- csv_file_info$num_threads
|
|
0 commit comments