Skip to content

process_data() improvements #538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ to vector, matrix, or array depending on the dimensions of the table. (#528)
* `install_cmdstan()` now automatically installs the Linux ARM CmdStan when
Linux distributions running on ARM CPUs are detected. (#531)

* Improved the conversion of named lists supplied to the `data` argument into
JSON data files: the list is now checked for all required elements/Stan
variables; arrays/vectors of length 1 are better differentiated from scalars
when generating JSON data files; and floating point numbers are written with
decimal points to fix an issue with parsing large numbers. (#538)

# cmdstanr 0.4.0

### Bug fixes
Expand Down
46 changes: 43 additions & 3 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#' @export
#' @param data (list) A named list of \R objects.
#' @param file (string) The path to where the data file should be written.
#' @param always_decimal (logical) If `TRUE`, non-integer values are always
#' written with a decimal point to better distinguish between integers and
#' floating point values. If `TRUE`, all \R objects in `data` intended to be
#' integers must be of integer type.
#'
#' @details
#' `write_stan_json()` performs several conversions before writing the JSON
Expand Down Expand Up @@ -52,7 +56,7 @@
#' write_stan_json(data, file)
#' cat(readLines(file), sep = "\n")
#'
write_stan_json <- function(data, file) {
write_stan_json <- function(data, file, always_decimal = FALSE) {
if (!is.list(data)) {
stop("'data' must be a list.", call. = FALSE)
}
Expand Down Expand Up @@ -99,6 +103,7 @@ write_stan_json <- function(data, file) {
path = file,
auto_unbox = TRUE,
factor = "integer",
always_decimal = always_decimal,
digits = NA,
pretty = TRUE
)
Expand Down Expand Up @@ -133,8 +138,14 @@ list_to_array <- function(x, name = NULL) {
#' @noRd
#' @param data If not `NULL`, then either a path to a data file compatible with
#' CmdStan, or a named list of \R objects to pass to [write_stan_json()].
#' @param stan_file If not `NULL`, the path to the Stan model for which to
#' process the named list supplied to the `data` argument. The Stan model
#' is used for checking whether the supplied named list has all the
#' required elements/Stan variables and to help differentiate between a
#' vector of length 1 and a scalar when generating the JSON file. This
#' argument is ignored when a path to a data file is supplied for `data`.
#' @return Path to data file.
process_data <- function(data) {
process_data <- function(data, stan_file = NULL) {
if (length(data) == 0) {
data <- NULL
}
Expand All @@ -151,8 +162,37 @@ process_data <- function(data) {
call. = FALSE
)
}
# Validate and normalise the supplied list against the data variables that
# model_variables() reports for the Stan file, before the list is written
# to JSON.
# NOTE(review): the version gate below is a character comparison, not a
# semantic version comparison (e.g. "2.26.1" > "2.26" is TRUE) -- confirm
# that this is the intended gate.
if (cmdstan_version() > "2.26" && !is.null(stan_file)) {
  stan_file <- absolute_path(stan_file)
  if (file.exists(stan_file)) {
    data_variables <- model_variables(stan_file)$data
    # Every data variable of the model must be present in the list.
    is_data_supplied <- names(data_variables) %in% names(data)
    if (!all(is_data_supplied)) {
      missing <- names(data_variables[!is_data_supplied])
      stop(
        "Missing input data for the following data variables: ",
        paste0(missing, collapse = ", "),
        ".",
        call. = FALSE
      )
    }
    for (var_name in names(data_variables)) {
      var_info <- data_variables[[var_name]]
      # Make sure integer inputs are of integer type to avoid
      # generating a decimal point in write_stan_json. This runs before
      # the length-1 handling below so the array wrapper is not undone.
      if (var_info$type == "int" && !is.integer(data[[var_name]])) {
        value <- data[[var_name]]
        # as.integer() truncates toward zero; tell the user instead of
        # silently changing their data.
        if (is.numeric(value) && any(value != trunc(value), na.rm = TRUE)) {
          warning(
            "A non-integer value was supplied for '", var_name, "'. ",
            "It will be truncated to an integer.",
            call. = FALSE
          )
        }
        # as.integer() drops attributes, including 'dim', which would
        # flatten matrices/arrays; restore the dimensions afterwards.
        value_dim <- dim(value)
        value <- as.integer(value)
        if (!is.null(value_dim)) {
          dim(value) <- value_dim
        }
        data[[var_name]] <- value
      }
      # distinguish between scalars and arrays/vectors of length 1
      if (length(data[[var_name]]) == 1 && var_info$dimensions == 1) {
        data[[var_name]] <- array(data[[var_name]], dim = 1)
      }
    }
  }
}
path <- tempfile(pattern = "standata-", fileext = ".json")
write_stan_json(data = data, file = path)
write_stan_json(data = data, file = path, always_decimal = (cmdstan_version() > "2.26"))
} else {
stop("'data' should be a path or a named list.", call. = FALSE)
}
Expand Down
12 changes: 6 additions & 6 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ sample <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
Expand Down Expand Up @@ -954,7 +954,7 @@ sample_mpi <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
Expand Down Expand Up @@ -1059,7 +1059,7 @@ optimize <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
Expand Down Expand Up @@ -1169,7 +1169,7 @@ variational <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
save_latent_dynamics = save_latent_dynamics,
seed = seed,
init = init,
Expand Down Expand Up @@ -1271,7 +1271,7 @@ generate_quantities <- function(fitted_params,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = seq_along(fitted_params_files),
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
seed = seed,
output_dir = output_dir,
output_basename = output_basename,
Expand Down Expand Up @@ -1327,7 +1327,7 @@ diagnose_method <- function(data = NULL,
model_name = self$model_name(),
exe_file = self$exe_file(),
proc_ids = 1,
data_file = process_data(data),
data_file = process_data(data, self$stan_file()),
seed = seed,
init = init,
output_dir = output_dir,
Expand Down
7 changes: 6 additions & 1 deletion man/write_stan_json.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

89 changes: 89 additions & 0 deletions tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,40 @@ if (not_on_cran()) {
}

test_that("empty data list converted to NULL", {
  skip_on_cran()
  # Stan program with only parameters and model blocks (no data block).
  no_data_model <- write_stan_file("
parameters {
real y;
}
model {
y ~ std_normal();
}
")
  # An empty list yields NULL both without and with a Stan file.
  without_model <- process_data(list())
  expect_null(without_model)
  with_model <- process_data(list(), stan_file = no_data_model)
  expect_null(with_model)
})

test_that("process_data works for inputs of length one", {
  skip_on_cran()
  data <- list(val = 5)

  # `val` declared as a real scalar: read back as a scalar.
  real_model <- write_stan_file("
data {
real val;
}
")
  real_json <- jsonlite::read_json(process_data(data, stan_file = real_model))
  expect_equal(real_json, list(val = 5))

  # `val` declared as an int scalar: read back as a scalar.
  int_model <- write_stan_file("
data {
int val;
}
")
  int_json <- jsonlite::read_json(process_data(data, stan_file = int_model))
  expect_equal(int_json, list(val = 5))

  # `val` declared as a vector of size 1: read back as a one-element list.
  vector_model <- write_stan_file("
data {
vector[1] val;
}
")
  vector_json <- jsonlite::read_json(
    process_data(data, stan_file = vector_model)
  )
  expect_equal(vector_json, list(val = list(5)))
})

test_that("process_fitted_params() works with basic input types", {
Expand Down Expand Up @@ -242,3 +275,59 @@ test_that("process_fitted_params() works with draws_matrix", {
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"))
)
})

test_that("process_data() errors on missing variables", {
  # Uses write_stan_file() and process_data() with a Stan file, like the
  # other tests in this file that skip on CRAN.
  skip_on_cran()
  stan_file <- write_stan_file("
data {
real val1;
real val2;
}
")
  # fixed = TRUE: compare the message literally so the trailing "." is not
  # treated as a regular-expression wildcard.
  expect_error(
    process_data(data = list(val1 = 5), stan_file = stan_file),
    "Missing input data for the following data variables: val2.",
    fixed = TRUE
  )
  expect_error(
    process_data(data = list(val = 1), stan_file = stan_file),
    "Missing input data for the following data variables: val1, val2.",
    fixed = TRUE
  )
  # A model that declares no data block: supplying extra list elements is
  # not an error and a file path is returned.
  stan_file_no_data <- write_stan_file("
transformed data {
real val1 = 1;
real val2 = 2;
}
")
  v <- process_data(data = list(val1 = 5), stan_file = stan_file_no_data)
  expect_type(v, "character")
})

test_that("process_data() corrrectly casts integers and floating point numbers", {
  # Uses write_stan_file() and process_data() with a Stan file, like the
  # other tests in this file that skip on CRAN.
  skip_on_cran()
  stan_file <- write_stan_file("
data {
int a;
real b;
}
")
  # Compare whole (trimmed) lines exactly. The previous expect_match() calls
  # passed the expected text as `object` and the actual line as the pattern,
  # so an output of `"b": 2` would still have matched `"b": 2.0`.

  # Both values supplied as doubles: `a` is declared int, `b` real.
  test_file <- process_data(list(a = 1, b = 2), stan_file = stan_file)
  json_lines <- trimws(readLines(test_file))
  expect_identical(json_lines[2], "\"a\": 1,")
  expect_identical(json_lines[3], "\"b\": 2.0")

  # Integer input for `a` and a large whole-number double for `b`.
  test_file <- process_data(
    list(a = 1L, b = 1774000000),
    stan_file = stan_file
  )
  json_lines <- trimws(readLines(test_file))
  expect_identical(json_lines[2], "\"a\": 1,")
  expect_identical(json_lines[3], "\"b\": 1774000000.0")
})
18 changes: 3 additions & 15 deletions tests/testthat/test-failed-chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,10 @@ test_that("init warnings are shown", {

test_that("optimize error on bad data", {
mod <- testing_model("bernoulli")
suppressWarnings(
expect_message(
mod$optimize(data = list(a = c(1,2,3)), seed = 123),
"Exception: variable does not exist"
)
)
expect_warning(
utils::capture.output(
fit <- mod$optimize(data = list(a = c(1,2,3)), seed = 123)
),
"Fitting finished unexpectedly!"
expect_error(
mod$optimize(data = list(a = c(1,2,3)), seed = 123),
"Missing input data for the following data variables: N, y."
)
expect_error(fit$print(), "Fitting failed. Unable to print.")
expect_error(fit$summary(), "Fitting failed. Unable to retrieve the draws.")
expect_error(fit$draws(), "Fitting failed. Unable to retrieve the draws.")
expect_error(fit$metadata(), "Fitting failed. Unable to retrieve the metadata.")
})

test_that("errors when using draws after variational fais", {
Expand Down
26 changes: 26 additions & 0 deletions tests/testthat/test-json.R
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,29 @@ test_that("write_stan_json() errors if bad names", {
"All elements in 'data' list must have names"
)
})

test_that("write_stan_json() works with always_decimal = TRUE", {
  test_file <- tempfile(fileext = ".json")
  # Compare whole (trimmed) lines exactly. The previous expect_match() calls
  # passed the expected text as `object` and the actual line as the pattern,
  # so they could not tell `"b": 2` apart from `"b": 2.0`.

  # Default behaviour: a whole-number double is written without a decimal.
  write_stan_json(list(a = 1L, b = 2), test_file, always_decimal = FALSE)
  json_lines <- trimws(readLines(test_file))
  expect_identical(json_lines[2], "\"a\": 1,")
  expect_identical(json_lines[3], "\"b\": 2")

  # always_decimal = TRUE: the double gets a decimal point, the integer
  # does not.
  write_stan_json(list(a = 1L, b = 2), test_file, always_decimal = TRUE)
  json_lines <- trimws(readLines(test_file))
  expect_identical(json_lines[2], "\"a\": 1,")
  expect_identical(json_lines[3], "\"b\": 2.0")
})