Skip to content

Commit daa111d

Browse files
committed
fix empty matrix inv_metric with 1 parameter
1 parent 69135f4 commit daa111d

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

R/args.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ SampleArgs <- R6::R6Class(
254254
fileext = ".json"
255255
)
256256
for (i in seq_along(inv_metric_paths)) {
257-
if (length(inv_metric[[i]] == 1) && metric == "diag_e") {
257+
if (length(inv_metric[[i]]) == 1 && metric == "diag_e") {
258258
inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1))
259259
}
260260
write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i])

R/fit.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,9 +1742,9 @@ inv_metric <- function(matrix = TRUE) {
17421742
out <- private$inv_metric_
17431743
if (matrix && !is.matrix(out[[1]])) {
17441744
# convert each vector to a diagonal matrix
1745-
out <- lapply(out, diag)
1745+
out <- lapply(out, function(x) diag(x, nrow = length(x)))
17461746
} else if (length(out[[1]]) == 1) {
1747-
# convert each scalar to a 1x1 matrix
1747+
# convert each scalar to an array with dimension 1
17481748
out <- lapply(out, array, dim = c(1))
17491749
}
17501750
out

tests/testthat/test-model-sample-metric.R

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ set_cmdstan_path()
44
mod <- testing_model("bernoulli")
55
data_list <- testing_data("bernoulli")
66

7+
mod2 <- testing_model("logistic")
8+
data_list2 <- testing_data("logistic")
9+
710

811
test_that("sample() method works with provided inv_metrics", {
912
inv_metric_vector <- array(1, dim = c(1))
@@ -54,7 +57,7 @@ test_that("sample() method works with provided inv_metrics", {
5457
})
5558

5659

57-
test_that("sample() method works with inv_metrics extracted from previous fit", {
60+
test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", {
5861
expect_sample_output(fit_r <- mod$sample(data = data_list,
5962
chains = 2,
6063
seed = 123))
@@ -89,6 +92,41 @@ test_that("sample() method works with inv_metrics extracted from previous fit",
8992
seed = 123)))
9093
})
9194

95+
test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", {
96+
expect_sample_output(fit_r <- mod2$sample(data = data_list2,
97+
chains = 2,
98+
seed = 123))
99+
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
100+
inv_metric_matrix <- fit_r$inv_metric()
101+
102+
expect_equal(length(inv_metric_vector[[1]]), 4)
103+
expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4))
104+
105+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
106+
chains = 1,
107+
metric = "diag_e",
108+
inv_metric = inv_metric_vector[[1]],
109+
seed = 123)))
110+
111+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
112+
chains = 1,
113+
metric = "dense_e",
114+
inv_metric = inv_metric_matrix[[1]],
115+
seed = 123)))
116+
117+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
118+
chains = 2,
119+
metric = "diag_e",
120+
inv_metric = inv_metric_vector,
121+
seed = 123)))
122+
123+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
124+
chains = 2,
125+
metric = "dense_e",
126+
inv_metric = inv_metric_matrix,
127+
seed = 123)))
128+
})
129+
92130

93131
test_that("sample() method works with lists of inv_metrics", {
94132
inv_metric_vector <- array(1, dim = c(1))

0 commit comments

Comments
 (0)