Skip to content

Commit d7190ba

Browse files
authored
Merge pull request #414 from stan-dev/time_in_read_cmdstan_csv
Return time in read_cmdstan_csv for MCMC
2 parents 102663b + 586e9a5 commit d7190ba

File tree

7 files changed

+100
-7
lines changed

7 files changed

+100
-7
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: cmdstanr
22
Title: R Interface to 'CmdStan'
3-
Version: 0.3.0
3+
Version: 0.3.0.9000
44
Date: 2020-12-17
55
Authors@R:
66
c(person(given = "Jonah", family = "Gabry", role = c("aut", "cre"),

NEWS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# Items for next tagged release
2+
3+
### Bug fixes
4+
5+
### New features
6+
7+
* `read_cmdstan_csv` now also returns time for MCMC sampling CSV files.
8+
19
# cmdstanr 0.3.0
210

311
### Bug fixes

R/fit.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,9 @@ CmdStanFit$set("public", name = "data_file", value = data_file)
514514
#'
515515
#' @return
516516
#' A list with elements
517-
#' * `total`: (scalar) the total run time.
518-
#' * `chains`: (data frame) for MCMC only, timing info for the individual
517+
#' * `total`: (scalar) The total run time. For MCMC this may be different than
518+
#' the sum of the chain run times if parallelization was used.
519+
#' * `chains`: (data frame) For MCMC only, timing info for the individual
519520
#' chains. The data frame has columns `"chain_id"`, `"warmup"`, `"sampling"`,
520521
#' and `"total"`.
521522
#'

R/read_csv.R

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
#' For [sampling][model-method-sample] the returned list also includes the
2929
#' following components:
3030
#'
31+
#' * `time`: Run time information for the individual chains. The returned object
32+
#' is the same as for the [$time()][fit-method-time] method except the total run
33+
#' time can't be inferred from the CSV files (the chains may have been run in
34+
#' parallel) and is therefore `NA`.
3135
#' * `inv_metric`: A list (one element per chain) of inverse mass matrices
3236
#' or their diagonals, depending on the type of metric used.
3337
#' * `step_size`: A list (one element per chain) of the step sizes used.
@@ -120,6 +124,8 @@ read_cmdstan_csv <- function(files,
120124
step_size <- list()
121125
col_types <- NULL
122126
col_select <- NULL
127+
metadata <- NULL
128+
time <- data.frame()
123129
not_matching <- c()
124130
for (output_file in files) {
125131
if (is.null(metadata)) {
@@ -130,7 +136,9 @@ read_cmdstan_csv <- function(files,
130136
if (!is.null(metadata$step_size_adaptation)) {
131137
step_size[[as.character(metadata$id)]] <- metadata$step_size_adaptation
132138
}
133-
id <- metadata$id
139+
if (!is.null(metadata$time)) {
140+
time <- rbind(time, metadata$time)
141+
}
134142
} else {
135143
csv_file_info <- read_csv_metadata(output_file)
136144
check <- check_csv_metadata_matches(metadata, csv_file_info)
@@ -151,7 +159,9 @@ read_cmdstan_csv <- function(files,
151159
if (!is.null(csv_file_info$step_size_adaptation)) {
152160
step_size[[as.character(csv_file_info$id)]] <- csv_file_info$step_size_adaptation
153161
}
154-
id <- csv_file_info$id
162+
if (!is.null(csv_file_info$time)) {
163+
time <- rbind(time, csv_file_info$time)
164+
}
155165
}
156166
if (is.null(col_select)) {
157167
if (is.null(variables)) { # variables = NULL returns all
@@ -321,6 +331,7 @@ read_cmdstan_csv <- function(files,
321331
}
322332
list(
323333
metadata = metadata,
334+
time = list(total = NA_integer_, chains = time),
324335
inv_metric = inv_metric,
325336
step_size = step_size,
326337
warmup_draws = warmup_draws,
@@ -393,6 +404,9 @@ read_csv_metadata <- function(csv_file) {
393404
inv_metric_rows <- -1
394405
parsing_done <- FALSE
395406
dense_inv_metric <- FALSE
407+
warmup_time <- 0
408+
sampling_time <-0
409+
total_time <- 0
396410
if (os_is_windows()) {
397411
grep_path <- repair_path(Sys.which("grep.exe"))
398412
fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' --color=never ", csv_file)
@@ -467,6 +481,16 @@ read_csv_metadata <- function(csv_file) {
467481
csv_file_info[[key_val[1]]] <- key_val[2]
468482
}
469483
}
484+
} else if (grepl("(Warm-up)", tmp, fixed = TRUE)) {
485+
tmp <- gsub("Elapsed Time:", "", tmp, fixed = TRUE)
486+
tmp <- gsub("seconds (Warm-up)", "", tmp, fixed = TRUE)
487+
warmup_time <- as.numeric(tmp)
488+
} else if (grepl("(Sampling)", tmp, fixed = TRUE)) {
489+
tmp <- gsub("seconds (Sampling)", "", tmp, fixed = TRUE)
490+
sampling_time <- as.numeric(tmp)
491+
} else if (grepl("(Total)", tmp, fixed = TRUE)) {
492+
tmp <- gsub("seconds (Total)", "", tmp, fixed = TRUE)
493+
total_time <- as.numeric(tmp)
470494
}
471495
}
472496
}
@@ -493,6 +517,14 @@ read_csv_metadata <- function(csv_file) {
493517
} else {
494518
csv_file_info$threads_per_chain <- csv_file_info$num_threads
495519
}
520+
if (csv_file_info$method == "sample") {
521+
csv_file_info$time <- data.frame(
522+
chain_id = csv_file_info$id,
523+
warmup = warmup_time,
524+
sampling = sampling_time,
525+
total = total_time
526+
)
527+
}
496528
csv_file_info$model <- NULL
497529
csv_file_info$engaged <- NULL
498530
csv_file_info$delta <- NULL

man/fit-method-time.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/read_cmdstan_csv.Rd

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-csv.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,50 @@ test_that("stan_variables and stan_variable_dims works in read_cdmstan_csv()", {
471471
expect_equal(gq$metadata$stan_variable_dims, list(y_rep = 10, sum_y = 1))
472472
})
473473

474+
test_that("returning time works for read_cmdstan_csv", {
475+
csv_files <- test_path("resources", "csv", "model1-2-no-warmup.csv")
476+
csv_data <- read_cmdstan_csv(csv_files)
477+
expect_equal(csv_data$time$total, NA_integer_)
478+
expect_equal(csv_data$time$chains, data.frame(
479+
chain_id = 2,
480+
warmup = 0.017041,
481+
sampling = 0.022068,
482+
total = 0.039109
483+
))
484+
485+
csv_files <- test_path("resources", "csv", "model1-3-diff_args.csv")
486+
csv_data <- read_cmdstan_csv(csv_files)
487+
expect_equal(csv_data$time$total, NA_integer_)
488+
expect_equal(csv_data$time$chains, data.frame(
489+
chain_id = 1,
490+
warmup = 0.038029,
491+
sampling = 0.030711,
492+
total = 0.06874
493+
))
494+
495+
csv_files <- c(
496+
test_path("resources", "csv", "model1-1-warmup.csv"),
497+
test_path("resources", "csv", "model1-2-warmup.csv")
498+
)
499+
csv_data <- read_cmdstan_csv(csv_files)
500+
expect_equal(csv_data$time$total, NA_integer_)
501+
expect_equal(csv_data$time$chains, data.frame(
502+
chain_id = c(1,2),
503+
warmup = c(0.038029, 0.017041),
504+
sampling = c(0.030711, 0.022068),
505+
total = c(0.06874, 0.039109)
506+
))
507+
csv_files <- c(
508+
test_path("resources", "csv", "bernoulli-1-optimize.csv")
509+
)
510+
csv_data <- read_cmdstan_csv(csv_files)
511+
expect_null(csv_data$time$chains)
512+
})
513+
514+
test_that("time from read_cmdstan_csv matches time from fit$time()", {
515+
fit <- fit_bernoulli_thin_1
516+
expect_equivalent(
517+
read_cmdstan_csv(fit$output_files())$time$chains,
518+
fit$time()$chains
519+
)
520+
})

0 commit comments

Comments
 (0)