Skip to content

Commit 15aa9d9

Browse files
authored
Merge pull request #932 from venpopov/expose_new_stan_args
Expose new stan args
2 parents cc2e36d + 6d7ee0e commit 15aa9d9

19 files changed

+600
-277
lines changed

R/args.R

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ CmdStanArgs <- R6::R6Class(
4343
sig_figs = NULL,
4444
opencl_ids = NULL,
4545
model_variables = NULL,
46-
num_threads = NULL) {
46+
num_threads = NULL,
47+
save_cmdstan_config = NULL) {
4748

4849
self$model_name <- model_name
4950
self$stan_code <- stan_code
@@ -60,6 +61,7 @@ CmdStanArgs <- R6::R6Class(
6061
self$save_latent_dynamics <- save_latent_dynamics
6162
self$using_tempdir <- is.null(output_dir)
6263
self$model_variables <- model_variables
64+
self$save_cmdstan_config <- save_cmdstan_config
6365
if (os_is_wsl()) {
6466
# Want to ensure that any files under WSL are written to a tempdir within
6567
# WSL to avoid IO performance issues
@@ -87,6 +89,9 @@ CmdStanArgs <- R6::R6Class(
8789
self$opencl_ids <- opencl_ids
8890
self$num_threads = NULL
8991
self$method_args$validate(num_procs = length(self$proc_ids))
92+
if (is.logical(self$save_cmdstan_config)) {
93+
self$save_cmdstan_config <- as.integer(self$save_cmdstan_config)
94+
}
9095
self$validate()
9196
},
9297
validate = function() {
@@ -111,7 +116,7 @@ CmdStanArgs <- R6::R6Class(
111116
} else if (type == "profile") {
112117
basename <- paste0(basename, "-profile")
113118
}
114-
if (type == "output" && !is.null(self$output_basename)) {
119+
if (type == "output" && !is.null(self$output_basename)) {
115120
basename <- self$output_basename
116121
}
117122
generate_file_names(
@@ -180,6 +185,9 @@ CmdStanArgs <- R6::R6Class(
180185
if (!is.null(profile_file)) {
181186
args$output <- c(args$output, paste0("profile_file=", wsl_safe_path(profile_file)))
182187
}
188+
if (!is.null(self$save_cmdstan_config)) {
189+
args$output <- c(args$output, paste0("save_cmdstan_config=", self$save_cmdstan_config))
190+
}
183191
if (!is.null(self$opencl_ids)) {
184192
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
185193
}
@@ -218,7 +226,8 @@ SampleArgs <- R6::R6Class(
218226
term_buffer = NULL,
219227
window = NULL,
220228
fixed_param = FALSE,
221-
diagnostics = NULL) {
229+
diagnostics = NULL,
230+
save_metric = NULL) {
222231

223232
self$iter_warmup <- iter_warmup
224233
self$iter_sampling <- iter_sampling
@@ -232,6 +241,7 @@ SampleArgs <- R6::R6Class(
232241
self$inv_metric <- inv_metric
233242
self$fixed_param <- fixed_param
234243
self$diagnostics <- diagnostics
244+
self$save_metric <- save_metric
235245
if (identical(self$diagnostics, "")) {
236246
self$diagnostics <- NULL
237247
}
@@ -275,6 +285,9 @@ SampleArgs <- R6::R6Class(
275285
if (is.logical(self$save_warmup)) {
276286
self$save_warmup <- as.integer(self$save_warmup)
277287
}
288+
if (is.logical(self$save_metric)) {
289+
self$save_metric <- as.integer(self$save_metric)
290+
}
278291
invisible(self)
279292
},
280293
validate = function(num_procs) {
@@ -314,7 +327,8 @@ SampleArgs <- R6::R6Class(
314327
.make_arg("adapt_engaged"),
315328
.make_arg("init_buffer"),
316329
.make_arg("term_buffer"),
317-
.make_arg("window")
330+
.make_arg("window"),
331+
.make_arg("save_metric")
318332
)
319333
} else {
320334
new_args <- list(
@@ -335,7 +349,8 @@ SampleArgs <- R6::R6Class(
335349
.make_arg("adapt_engaged"),
336350
.make_arg("init_buffer"),
337351
.make_arg("term_buffer"),
338-
.make_arg("window")
352+
.make_arg("window"),
353+
.make_arg("save_metric")
339354
)
340355
}
341356
new_args <- do.call(c, new_args)
@@ -682,6 +697,7 @@ validate_cmdstan_args <- function(self) {
682697
checkmate::assert_flag(self$save_latent_dynamics)
683698
checkmate::assert_integerish(self$refresh, lower = 0, null.ok = TRUE)
684699
checkmate::assert_integerish(self$sig_figs, lower = 1, upper = 18, null.ok = TRUE)
700+
checkmate::assert_integerish(self$save_cmdstan_config, lower = 0, upper = 1, len = 1, null.ok = TRUE)
685701
if (!is.null(self$sig_figs) && cmdstan_version() < "2.25") {
686702
warning("The 'sig_figs' argument is only supported with cmdstan 2.25+ and will be ignored!", call. = FALSE)
687703
}
@@ -799,6 +815,15 @@ validate_sample_args <- function(self, num_procs) {
799815
checkmate::assert_subset(self$diagnostics, empty.ok = FALSE, choices = available_hmc_diagnostics())
800816
}
801817

818+
checkmate::assert_integerish(self$save_metric,
819+
lower = 0, upper = 1,
820+
len = 1,
821+
null.ok = TRUE)
822+
823+
if (is.null(self$adapt_engaged) || (!self$adapt_engaged && !is.null(self$save_metric))) {
824+
self$save_metric <- 0
825+
}
826+
802827
invisible(TRUE)
803828
}
804829

R/fit.R

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -898,10 +898,13 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
898898
#' Save output and data files
899899
#'
900900
#' @name fit-method-save_output_files
901-
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files fit-method-save_profile_files
902-
#' fit-method-output_files fit-method-data_file fit-method-latent_dynamics_files fit-method-profile_files
903-
#' save_output_files save_data_file save_latent_dynamics_files save_profile_files
904-
#' output_files data_file latent_dynamics_files profile_files
901+
#' @aliases fit-method-save_data_file fit-method-save_latent_dynamics_files
902+
#' fit-method-save_profile_files fit-method-output_files fit-method-data_file
903+
#' fit-method-latent_dynamics_files fit-method-profile_files
904+
#' fit-method-save_config_files fit-method-save_metric_files save_output_files
905+
#' save_data_file save_latent_dynamics_files save_profile_files
906+
#' save_config_files save_metric_files output_files data_file
907+
#' latent_dynamics_files profile_files config_files metric_files
905908
#'
906909
#' @description All fitted model objects have methods for saving (moving to a
907910
#' specified location) the files created by CmdStanR to hold CmdStan output
@@ -936,6 +939,14 @@ CmdStanFit$set("public", name = "cmdstan_diagnose", value = cmdstan_diagnose)
936939
#' `$save_output_files()` except `"-profile-"` is included in the new
937940
#' file name after `basename`.
938941
#'
942+
#' For `$save_metric_files()` everything is the same as for
943+
#' `$save_output_files()` except `"-metric-"` is included in the new
944+
#' file name after `basename`.
945+
#'
946+
#' For `$save_config_files()` everything is the same as for
947+
#' `$save_output_files()` except `"-config-"` is included in the new
948+
#' file name after `basename`.
949+
#'
939950
#' For `$save_data_file()` no `id` is included in the file name because even
940951
#' with multiple MCMC chains the data file is the same.
941952
#'
@@ -998,6 +1009,26 @@ save_data_file <- function(dir = ".",
9981009
}
9991010
CmdStanFit$set("public", name = "save_data_file", value = save_data_file)
10001011

1012+
#' @rdname fit-method-save_output_files
1013+
save_config_files <- function(dir = ".",
1014+
basename = NULL,
1015+
timestamp = TRUE,
1016+
random = TRUE) {
1017+
self$runset$save_config_files(dir, basename, timestamp, random)
1018+
}
1019+
CmdStanFit$set("public", name = "save_config_files", value = save_config_files)
1020+
1021+
#' @rdname fit-method-save_output_files
1022+
save_metric_files <- function(dir = ".",
1023+
basename = NULL,
1024+
timestamp = TRUE,
1025+
random = TRUE) {
1026+
self$runset$save_metric_files(dir, basename, timestamp, random)
1027+
}
1028+
CmdStanFit$set("public", name = "save_metric_files", value = save_metric_files)
1029+
1030+
1031+
10011032
#' @rdname fit-method-save_output_files
10021033
#' @param include_failed (logical) Should CmdStan runs that failed also be
10031034
#' included? The default is `FALSE.`
@@ -1024,6 +1055,17 @@ data_file <- function() {
10241055
}
10251056
CmdStanFit$set("public", name = "data_file", value = data_file)
10261057

1058+
#' @rdname fit-method-save_output_files
1059+
config_files <- function(include_failed = FALSE) {
1060+
self$runset$config_files(include_failed)
1061+
}
1062+
CmdStanFit$set("public", name = "config_files", value = config_files)
1063+
1064+
#' @rdname fit-method-save_output_files
1065+
metric_files <- function(include_failed = FALSE) {
1066+
self$runset$metric_files(include_failed)
1067+
}
1068+
CmdStanFit$set("public", name = "metric_files", value = metric_files)
10271069

10281070
#' Report timing of CmdStan runs
10291071
#'

R/model.R

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,6 +1149,8 @@ sample <- function(data = NULL,
11491149
show_messages = TRUE,
11501150
show_exceptions = TRUE,
11511151
diagnostics = c("divergences", "treedepth", "ebfmi"),
1152+
save_metric = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
1153+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
11521154
# deprecated
11531155
cores = NULL,
11541156
num_cores = NULL,
@@ -1240,7 +1242,8 @@ sample <- function(data = NULL,
12401242
term_buffer = term_buffer,
12411243
window = window,
12421244
fixed_param = fixed_param,
1243-
diagnostics = diagnostics
1245+
diagnostics = diagnostics,
1246+
save_metric = save_metric
12441247
)
12451248
args <- CmdStanArgs$new(
12461249
method_args = sample_args,
@@ -1260,7 +1263,8 @@ sample <- function(data = NULL,
12601263
output_basename = output_basename,
12611264
sig_figs = sig_figs,
12621265
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
1263-
model_variables = model_variables
1266+
model_variables = model_variables,
1267+
save_cmdstan_config = save_cmdstan_config
12641268
)
12651269
runset <- CmdStanRun$new(args, procs)
12661270
runset$run_cmdstan()
@@ -1357,6 +1361,7 @@ sample_mpi <- function(data = NULL,
13571361
show_messages = TRUE,
13581362
show_exceptions = TRUE,
13591363
diagnostics = c("divergences", "treedepth", "ebfmi"),
1364+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL },
13601365
# deprecated
13611366
validate_csv = TRUE) {
13621367

@@ -1420,7 +1425,8 @@ sample_mpi <- function(data = NULL,
14201425
output_dir = output_dir,
14211426
output_basename = output_basename,
14221427
sig_figs = sig_figs,
1423-
model_variables = model_variables
1428+
model_variables = model_variables,
1429+
save_cmdstan_config = save_cmdstan_config
14241430
)
14251431
runset <- CmdStanRun$new(args, procs)
14261432
runset$run_cmdstan_mpi(mpi_cmd, mpi_args)
@@ -1500,7 +1506,8 @@ optimize <- function(data = NULL,
15001506
tol_param = NULL,
15011507
history_size = NULL,
15021508
show_messages = TRUE,
1503-
show_exceptions = TRUE) {
1509+
show_exceptions = TRUE,
1510+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
15041511
procs <- CmdStanProcs$new(
15051512
num_procs = 1,
15061513
show_stderr_messages = show_exceptions,
@@ -1541,7 +1548,8 @@ optimize <- function(data = NULL,
15411548
output_basename = output_basename,
15421549
sig_figs = sig_figs,
15431550
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
1544-
model_variables = model_variables
1551+
model_variables = model_variables,
1552+
save_cmdstan_config = save_cmdstan_config
15451553
)
15461554
runset <- CmdStanRun$new(args, procs)
15471555
runset$run_cmdstan()
@@ -1632,7 +1640,8 @@ laplace <- function(data = NULL,
16321640
jacobian = TRUE, # different default than for optimize!
16331641
draws = NULL,
16341642
show_messages = TRUE,
1635-
show_exceptions = TRUE) {
1643+
show_exceptions = TRUE,
1644+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
16361645
if (cmdstan_version() < "2.32") {
16371646
stop("This method is only available in cmdstan >= 2.32", call. = FALSE)
16381647
}
@@ -1706,7 +1715,8 @@ laplace <- function(data = NULL,
17061715
output_basename = output_basename,
17071716
sig_figs = sig_figs,
17081717
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
1709-
model_variables = model_variables
1718+
model_variables = model_variables,
1719+
save_cmdstan_config = save_cmdstan_config
17101720
)
17111721
runset <- CmdStanRun$new(args, procs)
17121722
runset$run_cmdstan()
@@ -1786,7 +1796,8 @@ variational <- function(data = NULL,
17861796
output_samples = NULL,
17871797
draws = NULL,
17881798
show_messages = TRUE,
1789-
show_exceptions = TRUE) {
1799+
show_exceptions = TRUE,
1800+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
17901801
procs <- CmdStanProcs$new(
17911802
num_procs = 1,
17921803
show_stderr_messages = show_exceptions,
@@ -1827,7 +1838,8 @@ variational <- function(data = NULL,
18271838
output_basename = output_basename,
18281839
sig_figs = sig_figs,
18291840
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
1830-
model_variables = model_variables
1841+
model_variables = model_variables,
1842+
save_cmdstan_config = save_cmdstan_config
18311843
)
18321844
runset <- CmdStanRun$new(args, procs)
18331845
runset$run_cmdstan()
@@ -1929,7 +1941,8 @@ pathfinder <- function(data = NULL,
19291941
psis_resample = NULL,
19301942
calculate_lp = NULL,
19311943
show_messages = TRUE,
1932-
show_exceptions = TRUE) {
1944+
show_exceptions = TRUE,
1945+
save_cmdstan_config = if (cmdstan_version() > "2.34.0") { TRUE } else { NULL }) {
19331946
procs <- CmdStanProcs$new(
19341947
num_procs = 1,
19351948
show_stderr_messages = show_exceptions,
@@ -1976,7 +1989,8 @@ pathfinder <- function(data = NULL,
19761989
sig_figs = sig_figs,
19771990
opencl_ids = assert_valid_opencl(opencl_ids, self$cpp_options()),
19781991
model_variables = model_variables,
1979-
num_threads = num_threads
1992+
num_threads = num_threads,
1993+
save_cmdstan_config = save_cmdstan_config
19801994
)
19811995
runset <- CmdStanRun$new(args, procs)
19821996
runset$run_cmdstan()

0 commit comments

Comments
 (0)