Skip to content

Add OpenCL device selection at runtime and a OpenCL vignette #439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Apr 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d694881
add opencl device selection at runtime and a vignette
rok-cesnovar Jan 29, 2021
fd9e4ca
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Feb 12, 2021
a09560d
add docs
rok-cesnovar Feb 12, 2021
dc9d282
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Feb 25, 2021
bdfba29
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Mar 5, 2021
b1f100b
a few additional comments
rok-cesnovar Mar 5, 2021
e9f9955
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Mar 5, 2021
91d445a
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Mar 15, 2021
9326573
Edits
bbbales2 Mar 16, 2021
95c32e0
edits and change to opencl_ids
rok-cesnovar Mar 16, 2021
2102434
syntax fix
rok-cesnovar Mar 16, 2021
44d71c9
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Mar 17, 2021
e537e06
fix typo
rok-cesnovar Mar 17, 2021
976627d
Merge remote-tracking branch 'origin/master' into opencl_runtime_args
rok-cesnovar Mar 17, 2021
8be26a6
Apply suggestions from code review
rok-cesnovar Mar 17, 2021
1a790bf
reorder args
rok-cesnovar Mar 17, 2021
4b7059f
do not duplicate opencl checks
rok-cesnovar Mar 17, 2021
dcee32e
Merge remote-tracking branch 'origin/opencl_runtime_args' into opencl…
rok-cesnovar Mar 17, 2021
eaf5870
update .Rd files
rok-cesnovar Mar 17, 2021
84d95b4
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Mar 26, 2021
f8fbadb
expand introduction
rok-cesnovar Mar 26, 2021
e83952e
update vignette
rok-cesnovar Apr 1, 2021
46764ed
Merge branch 'master' into opencl_runtime_args
rok-cesnovar Apr 11, 2021
9222345
run test coverage on tests
rok-cesnovar Apr 11, 2021
ff9f619
Merge branch 'master' into opencl_runtime_args
jgabry Apr 14, 2021
4905dd8
gitignore the executables
jgabry Apr 15, 2021
8608e67
minor edits to opencl vignette
jgabry Apr 15, 2021
795b4f8
update vignette
rok-cesnovar Apr 15, 2021
f531df8
remove iteration prints
rok-cesnovar Apr 15, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/Test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ jobs:
shell: Rscript {0}

- name: Test coverage
run: covr::codecov(type = "all")
run: covr::codecov(type = "tests")
shell: Rscript {0}

test-coverage-windows:
Expand Down Expand Up @@ -131,5 +131,5 @@ jobs:
- name: Test coverage
run: |
options(covr.gcov = 'C:/rtools40/mingw64/bin/gcov.exe');
covr::codecov(type = "all", function_exclusions = "sample_mpi")
covr::codecov(type = "tests", function_exclusions = "sample_mpi")
shell: Rscript {0}
16 changes: 12 additions & 4 deletions R/args.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ CmdStanArgs <- R6::R6Class(
output_dir = NULL,
output_basename = NULL,
validate_csv = TRUE,
sig_figs = NULL) {
sig_figs = NULL,
opencl_ids = NULL) {

self$model_name <- model_name
self$exe_file <- exe_file
Expand All @@ -59,7 +60,7 @@ CmdStanArgs <- R6::R6Class(
init <- process_init_list(init, length(self$proc_ids))
}
self$init <- init

self$opencl_ids <- opencl_ids
self$method_args$validate(num_procs = length(self$proc_ids))
self$validate()
},
Expand Down Expand Up @@ -154,7 +155,9 @@ CmdStanArgs <- R6::R6Class(
if (!is.null(profile_file)) {
args$output <- c(args$output, paste0("profile_file=", profile_file))
}

if (!is.null(self$opencl_ids)) {
args$opencl <- c("opencl", paste0("platform=", self$opencl_ids[1]), paste0("device=", self$opencl_ids[2]))
}
args <- do.call(c, append(args, list(use.names = FALSE)))
self$method_args$compose(idx, args)
},
Expand Down Expand Up @@ -496,7 +499,12 @@ validate_cmdstan_args = function(self) {
num_procs <- length(self$proc_ids)
validate_init(self$init, num_procs)
validate_seed(self$seed, num_procs)

if (!is.null(self$opencl_ids)) {
if (cmdstan_version() < "2.25") {
stop("Runtime selection of OpenCL devices is only supported with CmdStan version 2.26 or newer.", call. = FALSE)
}
checkmate::assert_vector(self$opencl_ids, len = 2)
}
invisible(TRUE)
}

Expand Down
33 changes: 27 additions & 6 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ sample <- function(data = NULL,
parallel_chains = getOption("mc.cores", 1),
chain_ids = seq_len(chains),
threads_per_chain = NULL,
opencl_ids = NULL,
iter_warmup = NULL,
iter_sampling = NULL,
save_warmup = FALSE,
Expand Down Expand Up @@ -782,6 +783,7 @@ sample <- function(data = NULL,
call. = FALSE)
}
}
check_opencl(self$cpp_options(), opencl_ids)
sample_args <- SampleArgs$new(
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
Expand Down Expand Up @@ -812,7 +814,8 @@ sample <- function(data = NULL,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs,
validate_csv = validate_csv
validate_csv = validate_csv,
opencl_ids = opencl_ids
)
cmdstan_procs <- CmdStanMCMCProcs$new(
num_procs = chains,
Expand Down Expand Up @@ -1021,6 +1024,7 @@ optimize <- function(data = NULL,
output_basename = NULL,
sig_figs = NULL,
threads = NULL,
opencl_ids = NULL,
algorithm = NULL,
init_alpha = NULL,
iter = NULL,
Expand All @@ -1045,6 +1049,7 @@ optimize <- function(data = NULL,
call. = FALSE)
}
}
check_opencl(self$cpp_options(), opencl_ids)
optimize_args <- OptimizeArgs$new(
algorithm = algorithm,
init_alpha = init_alpha,
Expand All @@ -1068,7 +1073,8 @@ optimize <- function(data = NULL,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs
sig_figs = sig_figs,
opencl_ids = opencl_ids
)

cmdstan_procs <- CmdStanProcs$new(
Expand Down Expand Up @@ -1143,6 +1149,7 @@ variational <- function(data = NULL,
output_basename = NULL,
sig_figs = NULL,
threads = NULL,
opencl_ids = NULL,
algorithm = NULL,
iter = NULL,
grad_samples = NULL,
Expand All @@ -1168,6 +1175,7 @@ variational <- function(data = NULL,
call. = FALSE)
}
}
check_opencl(self$cpp_options(), opencl_ids)
variational_args <- VariationalArgs$new(
algorithm = algorithm,
iter = iter,
Expand All @@ -1192,7 +1200,8 @@ variational <- function(data = NULL,
refresh = refresh,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs
sig_figs = sig_figs,
opencl_ids = opencl_ids
)

cmdstan_procs <- CmdStanProcs$new(
Expand Down Expand Up @@ -1278,7 +1287,8 @@ generate_quantities <- function(fitted_params,
output_basename = NULL,
sig_figs = NULL,
parallel_chains = getOption("mc.cores", 1),
threads_per_chain = NULL) {
threads_per_chain = NULL,
opencl_ids = NULL) {
checkmate::assert_integerish(parallel_chains, lower = 1, null.ok = TRUE)
checkmate::assert_integerish(threads_per_chain, lower = 1, len = 1, null.ok = TRUE)
if (is.null(self$cpp_options()[["stan_threads"]])) {
Expand All @@ -1295,7 +1305,7 @@ generate_quantities <- function(fitted_params,
call. = FALSE)
}
}

check_opencl(self$cpp_options(), opencl_ids)
fitted_params <- process_fitted_params(fitted_params)
chains <- length(fitted_params)
generate_quantities_args <- GenerateQuantitiesArgs$new(
Expand All @@ -1310,7 +1320,8 @@ generate_quantities <- function(fitted_params,
seed = seed,
output_dir = output_dir,
output_basename = output_basename,
sig_figs = sig_figs
sig_figs = sig_figs,
opencl_ids = opencl_ids
)
cmdstan_procs <- CmdStanGQProcs$new(
num_procs = chains,
Expand All @@ -1322,3 +1333,13 @@ generate_quantities <- function(fitted_params,
CmdStanGQ$new(runset)
}
CmdStanModel$set("public", name = "generate_quantities", value = generate_quantities)


check_opencl <- function(cpp_options, opencl_ids) {
if (is.null(cpp_options[["stan_opencl"]])
&& !is.null(opencl_ids)) {
stop("'opencl_ids' is set but the model was not compiled with for use with OpenCL.",
"\nRecompile the model with the 'cpp_options = list(stan_opencl = TRUE)'",
call. = FALSE)
}
}
3 changes: 0 additions & 3 deletions R/run.R
Original file line number Diff line number Diff line change
Expand Up @@ -824,9 +824,6 @@ CmdStanMCMCProcs <- R6::R6Class(
state <- 1.5
next_state <- 1.5
}
if (state < 3 && grepl("profile_file =", line, perl = TRUE)) {
next_state <- 3
}
if (state <= 3 && grepl("Rejecting initial value:", line, perl = TRUE)) {
state <- 2
next_state <- 2
Expand Down
3 changes: 2 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ articles:
profiling Stan programs, and using CmdStanR in R Markdown documents.
contents:
- cmdstanr-internals
- profiling
- r-markdown
- profiling
- opencl

reference:
- title: "Package description"
Expand Down
12 changes: 7 additions & 5 deletions docs/articles/index.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading