Skip to content

Commit e3f20f7

Browse files
authored
Merge pull request #538 from stan-dev/process_scalars_and_length_1_containers
`process_data()` improvements
2 parents 58f0980 + c08647c commit e3f20f7

File tree

7 files changed

+179
-25
lines changed

7 files changed

+179
-25
lines changed

NEWS.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ to vector, matrix, or array depending on the dimensions of the table. (#528)
3030
* `install_cmdstan()` now automatically installs the Linux ARM CmdStan when
3131
Linux distributions running on ARM CPUs are detected. (#531)
3232

33+
* Improved processing of named lists supplied to the `data` argument to JSON
34+
data files: checking whether the list includes all required elements/Stan
35+
variables; improved differentiating arrays/vectors of length 1 and scalars
36+
when generating JSON data files; generating floating point numbers with
37+
decimal points to fix issue with parsing large numbers. (#538)
38+
3339
# cmdstanr 0.4.0
3440

3541
### Bug fixes

R/data.R

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#' @export
44
#' @param data (list) A named list of \R objects.
55
#' @param file (string) The path to where the data file should be written.
6+
#' @param always_decimal (logical) Force generate non-integers with decimal
7+
#' points to better distinguish between integers and floating point values.
8+
#' If `TRUE` all \R objects in `data` intended for integers must be of integer
9+
#' type.
610
#'
711
#' @details
812
#' `write_stan_json()` performs several conversions before writing the JSON
@@ -52,7 +56,7 @@
5256
#' write_stan_json(data, file)
5357
#' cat(readLines(file), sep = "\n")
5458
#'
55-
write_stan_json <- function(data, file) {
59+
write_stan_json <- function(data, file, always_decimal = FALSE) {
5660
if (!is.list(data)) {
5761
stop("'data' must be a list.", call. = FALSE)
5862
}
@@ -99,6 +103,7 @@ write_stan_json <- function(data, file) {
99103
path = file,
100104
auto_unbox = TRUE,
101105
factor = "integer",
106+
always_decimal = always_decimal,
102107
digits = NA,
103108
pretty = TRUE
104109
)
@@ -133,8 +138,14 @@ list_to_array <- function(x, name = NULL) {
133138
#' @noRd
134139
#' @param data If not `NULL`, then either a path to a data file compatible with
135140
#' CmdStan, or a named list of \R objects to pass to [write_stan_json()].
141+
#' @param stan_file If not `NULL`, the path to the Stan model for which to
142+
#' process the named list suppiled to the `data` argument. The Stan model
143+
#' is used for checking whether the supplied named list has all the
144+
#' required elements/Stan variables and to help differentiate between a
145+
#' vector of length 1 and a scalar when genereting the JSON file. This
146+
#' argument is ignored when a path to a data file is supplied for `data`.
136147
#' @return Path to data file.
137-
process_data <- function(data) {
148+
process_data <- function(data, stan_file = NULL) {
138149
if (length(data) == 0) {
139150
data <- NULL
140151
}
@@ -151,8 +162,37 @@ process_data <- function(data) {
151162
call. = FALSE
152163
)
153164
}
165+
if (cmdstan_version() > "2.26" && !is.null(stan_file)) {
166+
stan_file <- absolute_path(stan_file)
167+
if (file.exists(stan_file)) {
168+
data_variables <- model_variables(stan_file)$data
169+
is_data_supplied <- names(data_variables) %in% names(data)
170+
if (!all(is_data_supplied)) {
171+
missing <- names(data_variables[!is_data_supplied])
172+
stop(
173+
"Missing input data for the following data variables: ",
174+
paste0(missing, collapse = ", "),
175+
".",
176+
call. = FALSE
177+
)
178+
}
179+
for(var_name in names(data_variables)) {
180+
# distinguish between scalars and arrays/vectors of length 1
181+
if (length(data[[var_name]]) == 1
182+
&& data_variables[[var_name]]$dimensions == 1) {
183+
data[[var_name]] <- array(data[[var_name]], dim = 1)
184+
}
185+
# Make sure integer inputs are of integer type to avoid
186+
# generating a decimal point in write_stan_json
187+
if (data_variables[[var_name]]$type == "int"
188+
&& !is.integer(data[[var_name]])) {
189+
data[[var_name]] <- as.integer(data[[var_name]])
190+
}
191+
}
192+
}
193+
}
154194
path <- tempfile(pattern = "standata-", fileext = ".json")
155-
write_stan_json(data = data, file = path)
195+
write_stan_json(data = data, file = path, always_decimal = (cmdstan_version() > "2.26"))
156196
} else {
157197
stop("'data' should be a path or a named list.", call. = FALSE)
158198
}

R/model.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ sample <- function(data = NULL,
819819
model_name = self$model_name(),
820820
exe_file = self$exe_file(),
821821
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
822-
data_file = process_data(data),
822+
data_file = process_data(data, self$stan_file()),
823823
save_latent_dynamics = save_latent_dynamics,
824824
seed = seed,
825825
init = init,
@@ -954,7 +954,7 @@ sample_mpi <- function(data = NULL,
954954
model_name = self$model_name(),
955955
exe_file = self$exe_file(),
956956
proc_ids = checkmate::assert_integerish(chain_ids, lower = 1, len = chains, unique = TRUE, null.ok = FALSE),
957-
data_file = process_data(data),
957+
data_file = process_data(data, self$stan_file()),
958958
save_latent_dynamics = save_latent_dynamics,
959959
seed = seed,
960960
init = init,
@@ -1059,7 +1059,7 @@ optimize <- function(data = NULL,
10591059
model_name = self$model_name(),
10601060
exe_file = self$exe_file(),
10611061
proc_ids = 1,
1062-
data_file = process_data(data),
1062+
data_file = process_data(data, self$stan_file()),
10631063
save_latent_dynamics = save_latent_dynamics,
10641064
seed = seed,
10651065
init = init,
@@ -1169,7 +1169,7 @@ variational <- function(data = NULL,
11691169
model_name = self$model_name(),
11701170
exe_file = self$exe_file(),
11711171
proc_ids = 1,
1172-
data_file = process_data(data),
1172+
data_file = process_data(data, self$stan_file()),
11731173
save_latent_dynamics = save_latent_dynamics,
11741174
seed = seed,
11751175
init = init,
@@ -1271,7 +1271,7 @@ generate_quantities <- function(fitted_params,
12711271
model_name = self$model_name(),
12721272
exe_file = self$exe_file(),
12731273
proc_ids = seq_along(fitted_params_files),
1274-
data_file = process_data(data),
1274+
data_file = process_data(data, self$stan_file()),
12751275
seed = seed,
12761276
output_dir = output_dir,
12771277
output_basename = output_basename,
@@ -1327,7 +1327,7 @@ diagnose_method <- function(data = NULL,
13271327
model_name = self$model_name(),
13281328
exe_file = self$exe_file(),
13291329
proc_ids = 1,
1330-
data_file = process_data(data),
1330+
data_file = process_data(data, self$stan_file()),
13311331
seed = seed,
13321332
init = init,
13331333
output_dir = output_dir,

man/write_stan_json.Rd

Lines changed: 6 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-data.R

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,40 @@ if (not_on_cran()) {
88
}
99

1010
test_that("empty data list converted to NULL", {
11+
skip_on_cran()
12+
stan_file <- write_stan_file("
13+
parameters {
14+
real y;
15+
}
16+
model {
17+
y ~ std_normal();
18+
}
19+
")
1120
expect_null(process_data(list()))
21+
expect_null(process_data(list(), stan_file = stan_file))
22+
})
23+
24+
test_that("process_data works for inputs of length one", {
25+
skip_on_cran()
26+
data <- list(val = 5)
27+
stan_file <- write_stan_file("
28+
data {
29+
real val;
30+
}
31+
")
32+
expect_equal(jsonlite::read_json(process_data(data, stan_file = stan_file)), list(val = 5))
33+
stan_file <- write_stan_file("
34+
data {
35+
int val;
36+
}
37+
")
38+
expect_equal(jsonlite::read_json(process_data(data, stan_file = stan_file)), list(val = 5))
39+
stan_file <- write_stan_file("
40+
data {
41+
vector[1] val;
42+
}
43+
")
44+
expect_equal(jsonlite::read_json(process_data(data, stan_file = stan_file)), list(val = list(5)))
1245
})
1346

1447
test_that("process_fitted_params() works with basic input types", {
@@ -242,3 +275,59 @@ test_that("process_fitted_params() works with draws_matrix", {
242275
posterior::subset_draws(fit_params_tmp, variable = c("alpha", "beta[1]", "beta[2]", "beta[3]"))
243276
)
244277
})
278+
279+
test_that("process_data() errors on missing variables", {
280+
stan_file <- write_stan_file("
281+
data {
282+
real val1;
283+
real val2;
284+
}
285+
")
286+
expect_error(
287+
process_data(data = list(val1 = 5), stan_file = stan_file),
288+
"Missing input data for the following data variables: val2."
289+
)
290+
expect_error(
291+
process_data(data = list(val = 1), stan_file = stan_file),
292+
"Missing input data for the following data variables: val1, val2."
293+
)
294+
stan_file_no_data <- write_stan_file("
295+
transformed data {
296+
real val1 = 1;
297+
real val2 = 2;
298+
}
299+
")
300+
v <- process_data(data = list(val1 = 5), stan_file = stan_file_no_data)
301+
expect_type(v, "character")
302+
})
303+
304+
test_that("process_data() corrrectly casts integers and floating point numbers", {
305+
stan_file <- write_stan_file("
306+
data {
307+
int a;
308+
real b;
309+
}
310+
")
311+
test_file <- process_data(list(a = 1, b = 2), stan_file = stan_file)
312+
expect_match(
313+
" \"a\": 1,",
314+
readLines(test_file)[2],
315+
fixed = TRUE
316+
)
317+
expect_match(
318+
" \"b\": 2.0",
319+
readLines(test_file)[3],
320+
fixed = TRUE
321+
)
322+
test_file <- process_data(list(a = 1L, b = 1774000000), stan_file = stan_file)
323+
expect_match(
324+
" \"a\": 1,",
325+
readLines(test_file)[2],
326+
fixed = TRUE
327+
)
328+
expect_match(
329+
" \"b\": 1774000000.0",
330+
readLines(test_file)[3],
331+
fixed = TRUE
332+
)
333+
})

tests/testthat/test-failed-chains.R

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -158,22 +158,10 @@ test_that("init warnings are shown", {
158158

159159
test_that("optimize error on bad data", {
160160
mod <- testing_model("bernoulli")
161-
suppressWarnings(
162-
expect_message(
163-
mod$optimize(data = list(a = c(1,2,3)), seed = 123),
164-
"Exception: variable does not exist"
165-
)
166-
)
167-
expect_warning(
168-
utils::capture.output(
169-
fit <- mod$optimize(data = list(a = c(1,2,3)), seed = 123)
170-
),
171-
"Fitting finished unexpectedly!"
161+
expect_error(
162+
mod$optimize(data = list(a = c(1,2,3)), seed = 123),
163+
"Missing input data for the following data variables: N, y."
172164
)
173-
expect_error(fit$print(), "Fitting failed. Unable to print.")
174-
expect_error(fit$summary(), "Fitting failed. Unable to retrieve the draws.")
175-
expect_error(fit$draws(), "Fitting failed. Unable to retrieve the draws.")
176-
expect_error(fit$metadata(), "Fitting failed. Unable to retrieve the metadata.")
177165
})
178166

179167
test_that("errors when using draws after variational fais", {

tests/testthat/test-json.R

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,29 @@ test_that("write_stan_json() errors if bad names", {
182182
"All elements in 'data' list must have names"
183183
)
184184
})
185+
186+
test_that("write_stan_json() works with always_decimal = TRUE", {
187+
test_file <- tempfile(fileext = ".json")
188+
write_stan_json(list(a = 1L, b = 2), test_file, always_decimal = FALSE)
189+
expect_match(
190+
" \"a\": 1,",
191+
readLines(test_file)[2],
192+
fixed = TRUE
193+
)
194+
expect_match(
195+
" \"b\": 2",
196+
readLines(test_file)[3],
197+
fixed = TRUE
198+
)
199+
write_stan_json(list(a = 1L, b = 2), test_file, always_decimal = TRUE)
200+
expect_match(
201+
" \"a\": 1,",
202+
readLines(test_file)[2],
203+
fixed = TRUE
204+
)
205+
expect_match(
206+
" \"b\": 2.0",
207+
readLines(test_file)[3],
208+
fixed = TRUE
209+
)
210+
})

0 commit comments

Comments
 (0)