Skip to content

Commit f8f6321

Browse files
authored
Merge pull request #696 from andrjohns/wsl-cmdstan-internal
WSL - Run `cmdstan` and models under WSL filesystem
2 parents 53084da + 71c0ea7 commit f8f6321

20 files changed

+410
-171
lines changed

.github/workflows/R-CMD-check-wsl.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,14 @@ jobs:
5454
run: |
5555
remotes::install_deps(dependencies = TRUE)
5656
remotes::install_cran("rcmdcheck")
57-
remotes::install_local(path = ".")
57+
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
5858
install.packages("curl")
5959
shell: Rscript {0}
6060

6161
- uses: Vampire/setup-wsl@v1
6262
with:
6363
distribution: Ubuntu-22.04
64-
use-cache: 'true'
64+
use-cache: 'false'
6565
set-as-default: 'true'
6666
- name: Install WSL Dependencies
6767
run: |
@@ -74,6 +74,7 @@ jobs:
7474

7575
- name: Install cmdstan
7676
run: |
77+
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
7778
cmdstanr::install_cmdstan(cores = 2, wsl = TRUE, overwrite = TRUE)
7879
shell: Rscript {0}
7980

.github/workflows/R-CMD-check.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ jobs:
7979

8080
- name: Install dependencies
8181
run: |
82+
Sys.setenv("MAKEFLAGS"="-j2")
8283
remotes::install_deps(dependencies = TRUE)
8384
remotes::install_cran("rcmdcheck")
8485
remotes::install_local(path = ".")

.github/workflows/Test-coverage.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ jobs:
5959
- name: Install dependencies
6060
run: |
6161
install.packages(c("remotes", "curl"), dependencies = TRUE)
62-
remotes::install_local(path = ".")
62+
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
6363
remotes::install_deps(dependencies = TRUE)
6464
remotes::install_cran("covr")
6565
remotes::install_cran("gridExtra")

.github/workflows/cmdstan-tarball-check.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ jobs:
6666
run: |
6767
remotes::install_deps(dependencies = TRUE)
6868
remotes::install_cran("rcmdcheck")
69-
remotes::install_local(path = ".")
69+
remotes::install_local(path = ".", INSTALL_opts = "--no-test-load")
7070
cmdstanr::check_cmdstan_toolchain(fix = TRUE)
7171
if (Sys.getenv("CMDSTAN_TEST_TARBALL_URL") == "latest") {
7272
cmdstanr::install_cmdstan(cores = 2, overwrite = TRUE)

R/args.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,20 @@ CmdStanArgs <- R6::R6Class(
5757
self$save_latent_dynamics <- save_latent_dynamics
5858
self$using_tempdir <- is.null(output_dir)
5959
self$model_variables <- model_variables
60-
if (getRversion() < "3.5.0") {
60+
if (os_is_wsl()) {
61+
# Want to ensure that any files under WSL are written to a tempdir within
62+
# WSL to avoid IO performance issues
63+
self$output_dir <- ifelse(is.null(output_dir),
64+
file.path(wsl_dir_prefix(), wsl_tempdir()),
65+
wsl_safe_path(output_dir))
66+
} else if (getRversion() < "3.5.0") {
6167
self$output_dir <- output_dir %||% tempdir()
6268
} else {
63-
self$output_dir <- output_dir %||% tempdir(check = TRUE)
69+
if (getRversion() < "3.5.0") {
70+
self$output_dir <- output_dir %||% tempdir()
71+
} else {
72+
self$output_dir <- output_dir %||% tempdir(check = TRUE)
73+
}
6474
}
6575
self$output_dir <- repair_path(self$output_dir)
6676
self$output_basename <- output_basename
@@ -525,8 +535,7 @@ DiagnoseArgs <- R6::R6Class(
525535
#' @return `TRUE` invisibly unless an error is thrown.
526536
validate_cmdstan_args <- function(self) {
527537
validate_exe_file(self$exe_file)
528-
529-
checkmate::assert_directory_exists(self$output_dir, access = "rw")
538+
assert_dir_exists(self$output_dir, access = "rw")
530539

531540
# at least 1 run id (chain id)
532541
checkmate::assert_integerish(self$proc_ids,
@@ -545,7 +554,7 @@ validate_cmdstan_args <- function(self) {
545554
self$refresh <- as.integer(self$refresh)
546555
}
547556
if (!is.null(self$data_file)) {
548-
checkmate::assert_file_exists(self$data_file, access = "r")
557+
assert_file_exists(self$data_file, access = "r")
549558
}
550559
num_procs <- length(self$proc_ids)
551560
validate_init(self$init, num_procs)
@@ -698,7 +707,7 @@ validate_optimize_args <- function(self) {
698707
#' @return `TRUE` invisibly unless an error is thrown.
699708
validate_generate_quantities_args <- function(self) {
700709
if (!is.null(self$fitted_params)) {
701-
checkmate::assert_file_exists(self$fitted_params, access = "r")
710+
assert_file_exists(self$fitted_params, access = "r")
702711
}
703712

704713
invisible(TRUE)
@@ -895,7 +904,7 @@ validate_init <- function(init, num_procs) {
895904
"length 1 or number of chains.",
896905
call. = FALSE)
897906
}
898-
checkmate::assert_file_exists(init, access = "r")
907+
assert_file_exists(init, access = "r")
899908
}
900909

901910
invisible(TRUE)
@@ -983,7 +992,7 @@ validate_metric_file <- function(metric_file, num_procs) {
983992
return(invisible(TRUE))
984993
}
985994

986-
checkmate::assert_file_exists(metric_file, access = "r")
995+
assert_file_exists(metric_file, access = "r")
987996

988997
if (length(metric_file) != 1 && length(metric_file) != num_procs) {
989998
stop(length(metric_file), " metric(s) provided. Must provide ",

R/csv.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ read_cmdstan_csv <- function(files,
125125
sampler_diagnostics = NULL,
126126
format = getOption("cmdstanr_draws_format", NULL)) {
127127
format <- assert_valid_draws_format(format)
128-
checkmate::assert_file_exists(files, access = "r", extension = "csv")
128+
assert_file_exists(files, access = "r", extension = "csv")
129129
metadata <- NULL
130130
warmup_draws <- list()
131131
draws <- list()
@@ -237,7 +237,7 @@ read_cmdstan_csv <- function(files,
237237
fread_cmd <- paste0(
238238
grep_path_quotes,
239239
" -v \"^#\" --color=never \"",
240-
output_file,
240+
wsl_safe_path(output_file, revert = TRUE),
241241
"\""
242242
)
243243
} else {
@@ -556,7 +556,7 @@ for (method in unavailable_methods_CmdStanFit_CSV) {
556556
#' mass matrix (or its diagonal depending on the metric).
557557
#'
558558
read_csv_metadata <- function(csv_file) {
559-
checkmate::assert_file_exists(csv_file, access = "r", extension = "csv")
559+
assert_file_exists(csv_file, access = "r", extension = "csv")
560560
inv_metric_next <- FALSE
561561
csv_file_info <- list()
562562
csv_file_info$inv_metric <- NULL
@@ -579,7 +579,7 @@ read_csv_metadata <- function(csv_file) {
579579
fread_cmd <- paste0(
580580
grep_path_quotes,
581581
" \"^[#a-zA-Z]\" --color=never \"",
582-
csv_file,
582+
wsl_safe_path(csv_file, revert = TRUE),
583583
"\""
584584
)
585585
} else {

R/fit.R

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ CmdStanFit <- R6::R6Class(
1919
if (!is.null(private$model_methods_env_$model_ptr)) {
2020
initialize_model_pointer(private$model_methods_env_, self$data_file(), 0)
2121
}
22+
# Need to update the output directory path to one that can be accessed
23+
# from Windows, for the post-processing of results
24+
self$runset$args$output_dir <- wsl_safe_path(self$runset$args$output_dir,
25+
revert = TRUE)
2226
invisible(self)
2327
},
2428
num_procs = function() {
@@ -303,6 +307,11 @@ CmdStanFit$set("public", name = "init", value = init)
303307
#' }
304308
#'
305309
init_model_methods <- function(seed = 0, verbose = FALSE, hessian = FALSE) {
310+
if (os_is_wsl()) {
311+
stop("Additional model methods are not currently available with ",
312+
"WSL CmdStan and will not be compiled",
313+
call. = FALSE)
314+
}
306315
require_suggested_package("Rcpp")
307316
require_suggested_package("RcppEigen")
308317
if (length(private$model_methods_env_$hpp_code_) == 0) {

R/install.R

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@ install_cmdstan <- function(dir = NULL,
9393
call. = FALSE)
9494
wsl <- FALSE
9595
} else {
96-
Sys.setenv("CMDSTANR_USE_WSL" = 1)
96+
.cmdstanr$WSL <- TRUE
9797
}
98+
} else {
99+
.cmdstanr$WSL <- FALSE
98100
}
99101
if (check_toolchain) {
100102
check_cmdstan_toolchain(fix = FALSE, quiet = quiet)
@@ -108,13 +110,13 @@ install_cmdstan <- function(dir = NULL,
108110
}
109111
}
110112
if (is.null(dir)) {
111-
dir <- cmdstan_default_install_path()
113+
dir <- cmdstan_default_install_path(wsl = wsl)
112114
if (!dir.exists(dir)) {
113115
dir.create(dir, recursive = TRUE)
114116
}
115117
} else {
116118
dir <- repair_path(dir)
117-
checkmate::assert_directory_exists(dir, access = "rwx")
119+
assert_dir_exists(dir, access = "rwx")
118120
}
119121
if (!is.null(version)) {
120122
if (!is.null(release_url)) {
@@ -125,7 +127,6 @@ install_cmdstan <- function(dir = NULL,
125127
release_url <- paste0("https://github.com/stan-dev/cmdstan/releases/download/v",
126128
version, "/cmdstan-", version, cmdstan_arch_suffix(version), ".tar.gz")
127129
}
128-
wsl_prefix <- ifelse(isTRUE(wsl), "wsl-", "")
129130
if (!is.null(release_url)) {
130131
if (!endsWith(release_url, ".tar.gz")) {
131132
stop(release_url, " is not a .tar.gz archive!",
@@ -137,14 +138,14 @@ install_cmdstan <- function(dir = NULL,
137138
tar_name <- utils::tail(split_url[[1]], n = 1)
138139
cmdstan_ver <- substr(tar_name, 0, nchar(tar_name) - 7)
139140
tar_gz_file <- paste0(cmdstan_ver, ".tar.gz")
140-
dir_cmdstan <- file.path(dir, paste0(wsl_prefix, cmdstan_ver))
141+
dir_cmdstan <- file.path(dir, cmdstan_ver)
141142
dest_file <- file.path(dir, tar_gz_file)
142143
} else {
143144
ver <- latest_released_version()
144145
message("* Latest CmdStan release is v", ver)
145146
cmdstan_ver <- paste0("cmdstan-", ver, cmdstan_arch_suffix(ver))
146147
tar_gz_file <- paste0(cmdstan_ver, ".tar.gz")
147-
dir_cmdstan <- file.path(dir, paste0(wsl_prefix, cmdstan_ver))
148+
dir_cmdstan <- file.path(dir, cmdstan_ver)
148149
message("* Installing CmdStan v", ver, " in ", dir_cmdstan)
149150
message("* Downloading ", tar_gz_file, " from GitHub...")
150151
download_url <- github_download_url(ver)
@@ -164,17 +165,34 @@ install_cmdstan <- function(dir = NULL,
164165
stop("Download of CmdStan failed. Please try again.", call. = FALSE)
165166
}
166167
message("* Download complete")
167-
168168
message("* Unpacking archive...")
169-
untar_rc <- utils::untar(
170-
dest_file,
171-
exdir = dir_cmdstan,
172-
extras = "--strip-components 1"
173-
)
174-
if (untar_rc != 0) {
175-
stop("Problem extracting tarball. Exited with return code: ", untar_rc, call. = FALSE)
169+
if (wsl) {
170+
# Significantly faster to use WSL to untar the downloaded archive, as there are
171+
# similar IO issues accessing the WSL filesystem from windows
172+
wsl_tar_gz_file <- gsub(paste0("//wsl$/", wsl_distro_name()), "",
173+
dest_file, fixed = TRUE)
174+
wsl_tar_gz_file <- wsl_safe_path(wsl_tar_gz_file)
175+
untar_rc <- processx::run(
176+
command = "wsl",
177+
args = c("tar", "-xf", wsl_tar_gz_file, "-C",
178+
gsub(tar_gz_file, "", wsl_tar_gz_file))
179+
)
180+
remove_rc <- processx::run(
181+
command = "wsl",
182+
args = c("rm", wsl_tar_gz_file)
183+
)
184+
} else {
185+
untar_rc <- utils::untar(
186+
dest_file,
187+
exdir = dir_cmdstan,
188+
extras = "--strip-components 1"
189+
)
190+
if (untar_rc != 0) {
191+
stop("Problem extracting tarball. Exited with return code: ", untar_rc, call. = FALSE)
192+
}
193+
file.remove(dest_file)
176194
}
177-
file.remove(dest_file)
195+
178196
cmdstan_make_local(dir = dir_cmdstan, cpp_options = cpp_options, append = TRUE)
179197
# Setting up native M1 compilation of CmdStan and its downstream libraries
180198
if (is_rosetta2()) {
@@ -186,7 +204,7 @@ install_cmdstan <- function(dir = NULL,
186204
append = TRUE
187205
)
188206
}
189-
if (is_rtools42_toolchain() && !os_is_wsl()) {
207+
if (is_rtools42_toolchain() && !wsl) {
190208
cmdstan_make_local(
191209
dir = dir_cmdstan,
192210
cpp_options = list(
@@ -521,10 +539,7 @@ install_toolchain <- function(quiet = FALSE) {
521539
}
522540

523541
check_wsl_toolchain <- function() {
524-
wsl_inaccessible <- processx::run(command = "wsl",
525-
args = "uname",
526-
error_on_status = FALSE)
527-
if (wsl_inaccessible$status) {
542+
if (!wsl_installed()) {
528543
stop("\n", "A WSL distribution is not installed or is not accessible.",
529544
"\n", "Please see the Microsoft documentation for guidance on installing WSL: ",
530545
"\n", "https://docs.microsoft.com/en-us/windows/wsl/install",

R/model.R

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ CmdStanModel <- R6::R6Class(
229229
self$functions <- new.env()
230230
self$functions$compiled <- FALSE
231231
if (!is.null(stan_file)) {
232-
checkmate::assert_file_exists(stan_file, access = "r", extension = "stan")
232+
assert_file_exists(stan_file, access = "r", extension = "stan")
233233
checkmate::assert_flag(compile)
234234
private$stan_file_ <- absolute_path(stan_file)
235235
private$stan_code_ <- readLines(stan_file)
@@ -250,7 +250,7 @@ CmdStanModel <- R6::R6Class(
250250
ext <- if (os_is_windows() && !os_is_wsl()) "exe" else ""
251251
private$exe_file_ <- repair_path(absolute_path(exe_file))
252252
if (is.null(stan_file)) {
253-
checkmate::assert_file_exists(private$exe_file_, access = "r", extension = ext)
253+
assert_file_exists(private$exe_file_, access = "r", extension = ext)
254254
private$model_name_ <- sub(" ", "_", strip_ext(basename(private$exe_file_)))
255255
}
256256
}
@@ -317,7 +317,7 @@ CmdStanModel <- R6::R6Class(
317317
if (is.null(dir)) {
318318
dir <- dirname(private$stan_file_)
319319
}
320-
checkmate::assert_directory_exists(dir, access = "r")
320+
assert_dir_exists(dir, access = "r")
321321
new_hpp_loc <- file.path(dir, paste0(strip_ext(basename(private$stan_file_)), ".hpp"))
322322
file.copy(self$hpp_file(), new_hpp_loc, overwrite = TRUE)
323323
file.remove(self$hpp_file())
@@ -471,7 +471,7 @@ compile <- function(quiet = TRUE,
471471
}
472472
if (!is.null(dir)) {
473473
dir <- repair_path(dir)
474-
checkmate::assert_directory_exists(dir, access = "rw")
474+
assert_dir_exists(dir, access = "rw")
475475
if (length(self$exe_file()) != 0) {
476476
private$exe_file_ <- file.path(dir, basename(self$exe_file()))
477477
}
@@ -524,6 +524,15 @@ compile <- function(quiet = TRUE,
524524
}
525525
}
526526

527+
if (os_is_wsl() && (compile_model_methods || compile_standalone)) {
528+
warning("Additional model methods and standalone functions are not ",
529+
"currently available with WSL CmdStan and will not be compiled",
530+
call. = FALSE)
531+
compile_model_methods <- FALSE
532+
compile_standalone <- FALSE
533+
compile_hessian_method <- FALSE
534+
}
535+
527536
temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan")
528537
file.copy(self$stan_file(), temp_stan_file, overwrite = TRUE)
529538
temp_file_no_ext <- strip_ext(temp_stan_file)
@@ -629,6 +638,12 @@ compile <- function(quiet = TRUE,
629638
file.remove(exe)
630639
}
631640
file.copy(tmp_exe, exe, overwrite = TRUE)
641+
if (os_is_wsl()) {
642+
res <- processx::run(
643+
command = "wsl",
644+
args = c("chmod", "+x", wsl_safe_path(exe))
645+
)
646+
}
632647
private$exe_file_ <- exe
633648
private$cpp_options_ <- cpp_options
634649
private$precompile_cpp_options_ <- NULL
@@ -1806,7 +1821,7 @@ cpp_options_to_compile_flags <- function(cpp_options) {
18061821
include_paths_stanc3_args <- function(include_paths = NULL) {
18071822
stancflags <- NULL
18081823
if (!is.null(include_paths)) {
1809-
checkmate::assert_directory_exists(include_paths, access = "r")
1824+
assert_dir_exists(include_paths, access = "r")
18101825
include_paths <- sapply(absolute_path(include_paths), wsl_safe_path)
18111826
paths_w_space <- grep(" ", include_paths)
18121827
include_paths[paths_w_space] <- paste0("'", include_paths[paths_w_space], "'")

0 commit comments

Comments
 (0)