Skip to content

Commit a02159c

Browse files
authored
Merge pull request #528 from stan-dev/handle-tables-indata
Handle tables in data
2 parents 38ea435 + f7aa762 commit a02159c

File tree

9 files changed

+123
-24
lines changed

9 files changed

+123
-24
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ to gradients computed via finite differences. (#485)
2020
so that models do not get unnecessarily recompiled when calling the function
2121
multiple times with the same code. (#495, @martinmodrak)
2222

23+
* `write_stan_json()` now handles data of class `"table"`. Tables are converted
24+
to vector, matrix, or array depending on the dimensions of the table. (#528)
25+
2326
# cmdstanr 0.4.0
2427

2528
### Bug fixes

R/data.R

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#' * `logical` -> `integer` (`TRUE` -> `1`, `FALSE` -> `0`)
1212
#' * `data.frame` -> `matrix` (via [data.matrix()])
1313
#' * `list` -> `array`
14+
#' * `table` -> `vector`, `matrix`, or `array` (depending on dimensions of table)
1415
#'
1516
#' The `list` to `array` conversion is intended to make it easier to prepare
1617
#' the data for certain Stan declarations involving arrays:
@@ -52,11 +53,20 @@
5253
#' cat(readLines(file), sep = "\n")
5354
#'
5455
write_stan_json <- function(data, file) {
56+
if (!is.list(data)) {
57+
stop("'data' must be a list.", call. = FALSE)
58+
}
5559
if (!is.character(file) || !nzchar(file)) {
5660
stop("The supplied filename is invalid!", call. = FALSE)
5761
}
5862

5963
data_names <- names(data)
64+
if (length(data) > 0 &&
65+
(length(data_names) == 0 ||
66+
length(data_names) != sum(nzchar(data_names)))) {
67+
stop("All elements in 'data' list must have names.", call. = FALSE)
68+
69+
}
6070
if (anyDuplicated(data_names) != 0) {
6171
stop("Duplicate names not allowed in 'data'.", call. = FALSE)
6272
}
@@ -67,8 +77,13 @@ write_stan_json <- function(data, file) {
6777
is.data.frame(var) || is.list(var))) {
6878
stop("Variable '", var_name, "' is of invalid type.", call. = FALSE)
6979
}
80+
if (anyNA(var)) {
81+
stop("Variable '", var_name, "' has NA values.", call. = FALSE)
82+
}
7083

71-
if (is.logical(var)) {
84+
if (is.table(var)) {
85+
var <- unclass(var)
86+
} else if (is.logical(var)) {
7287
mode(var) <- "integer"
7388
} else if (is.data.frame(var)) {
7489
var <- data.matrix(var)
@@ -136,9 +151,6 @@ process_data <- function(data) {
136151
call. = FALSE
137152
)
138153
}
139-
if (any_na_elements(data)) {
140-
stop("Data includes NA values.", call. = FALSE)
141-
}
142154
path <- tempfile(pattern = "standata-", fileext = ".json")
143155
write_stan_json(data = data, file = path)
144156
} else {
@@ -153,12 +165,6 @@ any_zero_dims <- function(data) {
153165
any(has_zero_dims)
154166
}
155167

156-
# check if any objects in the data list contain NAs
157-
any_na_elements <- function(data) {
158-
has_na_elements <- sapply(data, anyNA)
159-
any(has_na_elements)
160-
}
161-
162168
#' Write posterior draws objects to csv files
163169
#' @noRd
164170
#' @param draws A `draws_array` from posterior pkg

man/write_stan_json.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"x": [
3+
[
4+
[5, 0, 0, 0],
5+
[0, 0, 0, 0],
6+
[0, 0, 0, 0],
7+
[0, 0, 0, 0]
8+
],
9+
[
10+
[0, 0, 0, 0],
11+
[0, 5, 0, 0],
12+
[0, 0, 0, 0],
13+
[0, 0, 0, 0]
14+
],
15+
[
16+
[0, 0, 0, 0],
17+
[0, 0, 0, 0],
18+
[0, 0, 5, 0],
19+
[0, 0, 0, 0]
20+
],
21+
[
22+
[0, 0, 0, 0],
23+
[0, 0, 0, 0],
24+
[0, 0, 0, 0],
25+
[0, 0, 0, 5]
26+
]
27+
]
28+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"x": [
3+
[5, 0, 0, 0],
4+
[0, 5, 0, 0],
5+
[0, 0, 5, 0],
6+
[0, 0, 0, 5]
7+
]
8+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"x": [5, 5, 5, 5]
3+
}

tests/testthat/test-data.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@ test_that("empty data list converted to NULL", {
1111
expect_null(process_data(list()))
1212
})
1313

14-
test_that("NAs detected in data list", {
15-
expect_false(any_na_elements(list(y = 1)))
16-
expect_true(any_na_elements(list(y = 1, N = NA)))
17-
expect_true(any_na_elements(list(x = matrix(NA, 1, 1))))
18-
expect_true(any_na_elements(list(x = list(1, NA))))
19-
})
20-
2114
test_that("process_fitted_params() works with basic input types", {
2215
temp_file <- tempfile()
2316
temp_files <- c(tempfile(),

tests/testthat/test-json.R

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,19 +76,62 @@ test_that("JSON output for list of matrices is correct", {
7676
file = test_path("answers", "json-matrix-lists.json"))
7777
})
7878

79-
test_that("write_stan_json() throws correct errors", {
80-
skip_on_cran()
79+
test_that("JSON output for table is correct", {
8180
temp_file <- tempfile()
81+
f <- factor(rep(1:4, each = 5))
8282

83+
write_stan_json(list(x = table(f)), file = temp_file)
84+
json_output <- readLines(temp_file)
85+
expect_known_output(cat(json_output, sep = "\n"),
86+
file = test_path("answers", "json-table-vector.json"))
87+
88+
write_stan_json(list(x = table(f, f)), file = temp_file)
89+
json_output <- readLines(temp_file)
90+
expect_known_output(cat(json_output, sep = "\n"),
91+
file = test_path("answers", "json-table-matrix.json"))
92+
93+
write_stan_json(list(x = table(f, f, f)), file = temp_file)
94+
json_output <- readLines(temp_file)
95+
expect_known_output(cat(json_output, sep = "\n"),
96+
file = test_path("answers", "json-table-array.json"))
97+
})
98+
99+
test_that("write_stan_json errors if NAs", {
100+
expect_error(
101+
write_stan_json(list(y = 1, N = NA), tempfile()),
102+
"Variable 'N' has NA values"
103+
)
104+
expect_error(
105+
write_stan_json(list(x = matrix(NA, 1, 1)), tempfile()),
106+
"Variable 'x' has NA values"
107+
)
83108
expect_error(
84-
write_stan_json(list(N = c(1.0, 2.0, 3, 4)), file = c(1,2)),
109+
write_stan_json(list(x = list(1, NA)), tempfile()),
110+
"Variable 'x' has NA values"
111+
)
112+
})
113+
114+
test_that("write_stan_json() errors if data is not a list", {
115+
expect_error(
116+
write_stan_json(1:10),
117+
"'data' must be a list"
118+
)
119+
})
120+
121+
test_that("write_stan_json() errors if bad filename", {
122+
temp_file <- tempfile()
123+
124+
expect_error(
125+
write_stan_json(list(N = 10), file = c(1,2)),
85126
"The supplied filename is invalid!"
86127
)
87128
expect_error(
88-
write_stan_json(list(N = N), file = ""),
129+
write_stan_json(list(N = 10), file = ""),
89130
"The supplied filename is invalid!"
90131
)
132+
})
91133

134+
test_that("write_stan_json() errors if vectors/matrices in same list are different sizes", {
92135
expect_error(
93136
write_stan_json(list(N = list(c(26, 26, 26), c(26, 26))), file = "abc.txt"),
94137
"All matrices/vectors in list 'N' must be the same size!"
@@ -109,7 +152,9 @@ test_that("write_stan_json() throws correct errors", {
109152
write_stan_json(list(N = list(matrix(1:8, ncol = 2), matrix(1:9, ncol = 3))), file = "abc.txt"),
110153
"All matrices/vectors in list 'N' must be the same size!"
111154
)
155+
})
112156

157+
test_that("write_stan_json() errors if invalid types", {
113158
expect_error(
114159
write_stan_json(list(N = list("abc", "def")), file = "abc.txt"),
115160
"All elements in list 'N' must be numeric!"
@@ -119,9 +164,21 @@ test_that("write_stan_json() throws correct errors", {
119164
write_stan_json(list(N = "STRING"), file = "abc.txt"),
120165
"Variable 'N' is of invalid type"
121166
)
167+
})
122168

169+
test_that("write_stan_json() errors if bad names", {
123170
expect_error(
124171
write_stan_json(list(x = 1, y = 2, x = 3), file = tempfile()),
125172
"Duplicate names not allowed in 'data'"
126173
)
174+
175+
expect_error(
176+
write_stan_json(list(1, 2), tempfile()),
177+
"All elements in 'data' list must have names"
178+
)
179+
180+
expect_error(
181+
write_stan_json(list(a = 1, 2), tempfile()),
182+
"All elements in 'data' list must have names"
183+
)
127184
})

tests/testthat/test-model-data.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ test_that("error if data contains NA elements", {
4242
data_list2$y[3] <- NA
4343
data_list3$X[3, 2] <- NA
4444

45-
expect_error(mod$sample(data = data_list1), "Data includes NA values")
46-
expect_error(mod$sample(data = data_list2), "Data includes NA values")
47-
expect_error(mod$sample(data = data_list3), "Data includes NA values")
45+
expect_error(mod$sample(data = data_list1), "Variable 'N' has NA values")
46+
expect_error(mod$sample(data = data_list2), "Variable 'y' has NA values")
47+
expect_error(mod$sample(data = data_list3), "Variable 'X' has NA values")
4848
})
4949

5050
test_that("empty data list doesn't error if no data block", {

0 commit comments

Comments
 (0)