diff --git a/.github/workflows/Test-coverage.yaml b/.github/workflows/Test-coverage.yaml index b16f1997e..a2e2bbfc0 100644 --- a/.github/workflows/Test-coverage.yaml +++ b/.github/workflows/Test-coverage.yaml @@ -80,7 +80,7 @@ jobs: shell: Rscript {0} - name: Test coverage - run: covr::codecov(type = "all") + run: covr::codecov(type = "tests") shell: Rscript {0} test-coverage-windows: @@ -131,5 +131,5 @@ jobs: - name: Test coverage run: | options(covr.gcov = 'C:/rtools40/mingw64/bin/gcov.exe'); - covr::codecov(type = "all", function_exclusions = "sample_mpi") + covr::codecov(type = "tests", function_exclusions = "sample_mpi") shell: Rscript {0} diff --git a/R/args.R b/R/args.R index 7cc8590ea..2441e92a8 100644 --- a/R/args.R +++ b/R/args.R @@ -32,7 +32,8 @@ CmdStanArgs <- R6::R6Class( output_dir = NULL, output_basename = NULL, validate_csv = TRUE, - sig_figs = NULL) { + sig_figs = NULL, + opencl_ids = NULL) { self$model_name <- model_name self$exe_file <- exe_file @@ -59,7 +60,7 @@ CmdStanArgs <- R6::R6Class( init <- process_init_list(init, length(self$proc_ids)) } self$init <- init - + self$opencl_ids <- opencl_ids self$method_args$validate(num_procs = length(self$proc_ids)) self$validate() }, @@ -154,7 +155,9 @@ CmdStanArgs <- R6::R6Class( if (!is.null(profile_file)) { args$output <- c(args$output, paste0("profile_file=", profile_file)) } - + if (!is.null(self$opencl_ids)) { + args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2])) + } args <- do.call(c, append(args, list(use.names = FALSE))) self$method_args$compose(idx, args) }, @@ -496,7 +499,12 @@ validate_cmdstan_args = function(self) { num_procs <- length(self$proc_ids) validate_init(self$init, num_procs) validate_seed(self$seed, num_procs) - + if (!is.null(self$opencl_ids)) { + if (cmdstan_version() < "2.25") { + stop("Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer.", call. = FALSE) + } + checkmate::assert_vector(self$opencl_ids, len = 2) + } invisible(TRUE) } diff --git a/R/model.R b/R/model.R index 61c85a327..98df85116 100644 --- a/R/model.R +++ b/R/model.R @@ -698,6 +698,7 @@ sample <- function(data = NULL, parallel_chains = getOption("mc.cores", 1), chain_ids = seq_len(chains), threads_per_chain = NULL, + opencl_ids = NULL, iter_warmup = NULL, iter_sampling = NULL, save_warmup = FALSE, @@ -782,6 +783,7 @@ sample <- function(data = NULL, call. = FALSE) } } + check_opencl(self$cpp_options(), opencl_ids) sample_args <- SampleArgs$new( iter_warmup = iter_warmup, iter_sampling = iter_sampling, @@ -812,7 +814,8 @@ sample <- function(data = NULL, output_dir = output_dir, output_basename = output_basename, sig_figs = sig_figs, - validate_csv = validate_csv + validate_csv = validate_csv, + opencl_ids = opencl_ids ) cmdstan_procs <- CmdStanMCMCProcs$new( num_procs = chains, @@ -1021,6 +1024,7 @@ optimize <- function(data = NULL, output_basename = NULL, sig_figs = NULL, threads = NULL, + opencl_ids = NULL, algorithm = NULL, init_alpha = NULL, iter = NULL, @@ -1045,6 +1049,7 @@ optimize <- function(data = NULL, call. = FALSE) } } + check_opencl(self$cpp_options(), opencl_ids) optimize_args <- OptimizeArgs$new( algorithm = algorithm, init_alpha = init_alpha, @@ -1068,7 +1073,8 @@ optimize <- function(data = NULL, refresh = refresh, output_dir = output_dir, output_basename = output_basename, - sig_figs = sig_figs + sig_figs = sig_figs, + opencl_ids = opencl_ids ) cmdstan_procs <- CmdStanProcs$new( @@ -1143,6 +1149,7 @@ variational <- function(data = NULL, output_basename = NULL, sig_figs = NULL, threads = NULL, + opencl_ids = NULL, algorithm = NULL, iter = NULL, grad_samples = NULL, @@ -1168,6 +1175,7 @@ variational <- function(data = NULL, call. = FALSE) } } + check_opencl(self$cpp_options(), opencl_ids) variational_args <- VariationalArgs$new( algorithm = algorithm, iter = iter, @@ -1192,7 +1200,8 @@ variational <- function(data = NULL, refresh = refresh, output_dir = output_dir, output_basename = output_basename, - sig_figs = sig_figs + sig_figs = sig_figs, + opencl_ids = opencl_ids ) cmdstan_procs <- CmdStanProcs$new( @@ -1278,7 +1287,8 @@ generate_quantities <- function(fitted_params, output_basename = NULL, sig_figs = NULL, parallel_chains = getOption("mc.cores", 1), - threads_per_chain = NULL) { + threads_per_chain = NULL, + opencl_ids = NULL) { checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE) checkmate::assert_integerish(threads_per_chain, lower = 1, len = 1, null.ok = TRUE) if (is.null(self$cpp_options()[["stan_threads"]])) { @@ -1295,7 +1305,7 @@ generate_quantities <- function(fitted_params, call. = FALSE) } } - + check_opencl(self$cpp_options(), opencl_ids) fitted_params <- process_fitted_params(fitted_params) chains <- length(fitted_params) generate_quantities_args <- GenerateQuantitiesArgs$new( @@ -1310,7 +1320,8 @@ generate_quantities <- function(fitted_params, seed = seed, output_dir = output_dir, output_basename = output_basename, - sig_figs = sig_figs + sig_figs = sig_figs, + opencl_ids = opencl_ids ) cmdstan_procs <- CmdStanGQProcs$new( num_procs = chains, @@ -1322,3 +1333,13 @@ generate_quantities <- function(fitted_params, CmdStanGQ$new(runset) } CmdStanModel$set("public", name = "generate_quantities", value = generate_quantities) + + +check_opencl <- function(cpp_options, opencl_ids) { + if (is.null(cpp_options[["stan_opencl"]]) + && !is.null(opencl_ids)) { + stop("'opencl_ids' is set but the model was not compiled with for use with OpenCL.", + "\nRecompile the model with the 'cpp_options = list(stan_opencl = TRUE)'", + call. = FALSE) + } +} \ No newline at end of file diff --git a/R/run.R b/R/run.R index 7f5d7a764..65565abd4 100644 --- a/R/run.R +++ b/R/run.R @@ -824,9 +824,6 @@ CmdStanMCMCProcs <- R6::R6Class( state <- 1.5 next_state <- 1.5 } - if (state < 3 && grepl("profile_file =", line, perl = TRUE)) { - next_state <- 3 - } if (state <= 3 && grepl("Rejecting initial value:", line, perl = TRUE)) { state <- 2 next_state <- 2 diff --git a/_pkgdown.yml b/_pkgdown.yml index 6e1bb2b54..83cdb38b3 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -69,8 +69,9 @@ articles: profiling Stan programs, and using CmdStanR in R Markdown documents. contents: - cmdstanr-internals - - profiling - r-markdown + - profiling + - opencl reference: - title: "Package description" diff --git a/docs/articles/index.html b/docs/articles/index.html index e98e111da..20c814b3e 100644 --- a/docs/articles/index.html +++ b/docs/articles/index.html @@ -87,7 +87,7 @@