diff --git a/NAMESPACE b/NAMESPACE index 12a832d24..c1dd123dc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,6 +6,7 @@ S3method(as_draws,CmdStanMLE) S3method(as_draws,CmdStanVB) export(as_cmdstan_fit) export(as_draws) +export(as_mcmc.list) export(check_cmdstan_toolchain) export(cmdstan_default_install_path) export(cmdstan_default_path) diff --git a/NEWS.md b/NEWS.md index f116da217..d681ced57 100644 --- a/NEWS.md +++ b/NEWS.md @@ -60,6 +60,9 @@ recompilation of Stan models. (#580) * New methods for `posterior::as_draws()` for CmdStanR fitted model objects. These are just wrappers around the `$draws()` method provided for convenience. (#532) +* New function `as_mcmc.list()` for converting CmdStanMCMC objects to mcmc.list +objects from the coda package. (#584, @MatsuuraKentaro) + # cmdstanr 0.4.0 ### Bug fixes diff --git a/R/utils.R b/R/utils.R index 4da37d40a..76d91f935 100644 --- a/R/utils.R +++ b/R/utils.R @@ -329,3 +329,46 @@ maybe_convert_draws_format <- function(draws, format) { stop("Invalid draws format.", call. = FALSE) ) } + + +# convert draws for external packages ------------------------------------------ + +#' Convert `CmdStanMCMC` to `mcmc.list` +#' +#' This function converts a `CmdStanMCMC` object to an `mcmc.list` object +#' compatible with the \pkg{coda} package. This is primarily intended for users +#' of Stan coming from BUGS/JAGS who are used to \pkg{coda} for plotting and +#' diagnostics. In general we recommend the more recent MCMC diagnostics in +#' \pkg{posterior} and the \pkg{ggplot2}-based plotting functions in +#' \pkg{bayesplot}, but for users who prefer \pkg{coda} this function provides +#' compatibility. +#' +#' @export +#' @param x A [CmdStanMCMC] object. +#' @return An `mcmc.list` object compatible with the \pkg{coda} package. +#' @examples +#' \dontrun{ +#' fit <- cmdstanr_example() +#' x <- as_mcmc.list(fit) +#' } +#' +as_mcmc.list <- function(x) { + if (!inherits(x, "CmdStanMCMC")) { + stop("Currently only CmdStanMCMC objects can be converted to mcmc.list.", + call. = FALSE) + } + sample_array <- x$draws(format = "array") + n_chain <- posterior::nchains(sample_array) + n_iteration <- posterior::niterations(sample_array) + class(sample_array) <- 'array' + mcmc_list <- lapply(seq_len(n_chain), function(chain) { + x <- sample_array[, chain, ] + dimnames(x) <- list(iteration = dimnames(sample_array)$iteration, + variable = dimnames(sample_array)$variable) + attr(x, 'mcpar') <- c(1, n_iteration, 1) + class(x) <- 'mcmc' + x + }) + class(mcmc_list) <- 'mcmc.list' + return(mcmc_list) +} diff --git a/man/as_mcmc.list.Rd b/man/as_mcmc.list.Rd new file mode 100644 index 000000000..ce994cf09 --- /dev/null +++ b/man/as_mcmc.list.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{as_mcmc.list} +\alias{as_mcmc.list} +\title{Convert \code{CmdStanMCMC} to \code{mcmc.list}} +\usage{ +as_mcmc.list(x) +} +\arguments{ +\item{x}{A \link{CmdStanMCMC} object.} +} +\value{ +An \code{mcmc.list} object compatible with the \pkg{coda} package. +} +\description{ +This function converts a \code{CmdStanMCMC} object to an \code{mcmc.list} object +compatible with the \pkg{coda} package. This is primarily intended for users +of Stan coming from BUGS/JAGS who are used to \pkg{coda} for plotting and +diagnostics. In general we recommend the more recent MCMC diagnostics in +\pkg{posterior} and the \pkg{ggplot2}-based plotting functions in +\pkg{bayesplot}, but for users who prefer \pkg{coda} this function provides +compatibility. +} +\examples{ +\dontrun{ +fit <- cmdstanr_example() +x <- as_mcmc.list(fit) +} + +} diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 1699c6b45..cb52be2ec 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -4,6 +4,7 @@ if (not_on_cran()) { set_cmdstan_path() fit_mcmc <- testing_fit("logistic", method = "sample", seed = 123, chains = 2) + fit_mle <- testing_fit("logistic", method = "opt", seed = 123) } test_that("check_divergences() works", { @@ -178,3 +179,20 @@ test_that("require_suggested_package() works", { "Please install the 'not_a_real_package' package to use this function." ) }) + +test_that("as_mcmc.list() works", { + x <- as_mcmc.list(fit_mcmc) + expect_length(x, fit_mcmc$num_chains()) + expect_s3_class(x, "mcmc.list") + expect_s3_class(x[[1]], "mcmc") + + draws <- fit_mcmc$draws() + x1 <- x[[1]] + expect_equal(dim(x1), c(posterior::niterations(draws), posterior::nvariables(draws))) + expect_equal(dimnames(x1)$variable, posterior::variables(draws)) + + expect_error( + as_mcmc.list(fit_mle), + "Currently only CmdStanMCMC objects can be converted to mcmc.list" + ) +})