Skip to content

Commit ae1b7b3

Browse files
authored
Merge pull request #935 from venpopov/inv_metric_1par
Fix incorrect format of inv_metric when only 1 parameter in model
2 parents 82f9d9a + c69ef19 commit ae1b7b3

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

R/args.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ SampleArgs <- R6::R6Class(
254254
fileext = ".json"
255255
)
256256
for (i in seq_along(inv_metric_paths)) {
257+
if (length(inv_metric[[i]]) == 1 && metric == "diag_e") {
258+
inv_metric[[i]] <- array(inv_metric[[i]], dim = c(1))
259+
}
257260
write_stan_json(list(inv_metric = inv_metric[[i]]), inv_metric_paths[i])
258261
}
259262

R/fit.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1742,7 +1742,10 @@ inv_metric <- function(matrix = TRUE) {
17421742
out <- private$inv_metric_
17431743
if (matrix && !is.matrix(out[[1]])) {
17441744
# convert each vector to a diagonal matrix
1745-
out <- lapply(out, diag)
1745+
out <- lapply(out, function(x) diag(x, nrow = length(x)))
1746+
} else if (length(out[[1]]) == 1) {
1747+
# convert each scalar to an array with dimension 1
1748+
out <- lapply(out, array, dim = c(1))
17461749
}
17471750
out
17481751
}

tests/testthat/test-model-sample-metric.R

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ set_cmdstan_path()
44
mod <- testing_model("bernoulli")
55
data_list <- testing_data("bernoulli")
66

7+
mod2 <- testing_model("logistic")
8+
data_list2 <- testing_data("logistic")
9+
710

811
test_that("sample() method works with provided inv_metrics", {
912
inv_metric_vector <- array(1, dim = c(1))
@@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", {
5457
})
5558

5659

60+
test_that("sample() method works with inv_metrics extracted from previous fit with 1 parameter", {
61+
expect_sample_output(fit_r <- mod$sample(data = data_list,
62+
chains = 2,
63+
seed = 123))
64+
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
65+
inv_metric_matrix <- fit_r$inv_metric()
66+
67+
expect_equal(dim(inv_metric_vector[[1]]), 1)
68+
expect_equal(dim(inv_metric_matrix[[1]]), c(1, 1))
69+
70+
expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
71+
chains = 1,
72+
metric = "diag_e",
73+
inv_metric = inv_metric_vector[[1]],
74+
seed = 123)))
75+
76+
expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
77+
chains = 1,
78+
metric = "dense_e",
79+
inv_metric = inv_metric_matrix[[1]],
80+
seed = 123)))
81+
82+
expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
83+
chains = 2,
84+
metric = "diag_e",
85+
inv_metric = inv_metric_vector,
86+
seed = 123)))
87+
88+
expect_silent(expect_sample_output(fit_r <- mod$sample(data = data_list,
89+
chains = 2,
90+
metric = "dense_e",
91+
inv_metric = inv_metric_matrix,
92+
seed = 123)))
93+
})
94+
95+
test_that("sample() method works with inv_metrics extracted from previous fit with > 1 parameter", {
96+
expect_sample_output(fit_r <- mod2$sample(data = data_list2,
97+
chains = 2,
98+
seed = 123))
99+
inv_metric_vector <- fit_r$inv_metric(matrix = FALSE)
100+
inv_metric_matrix <- fit_r$inv_metric()
101+
102+
expect_equal(length(inv_metric_vector[[1]]), 4)
103+
expect_equal(dim(inv_metric_matrix[[1]]), c(4, 4))
104+
105+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
106+
chains = 1,
107+
metric = "diag_e",
108+
inv_metric = inv_metric_vector[[1]],
109+
seed = 123)))
110+
111+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
112+
chains = 1,
113+
metric = "dense_e",
114+
inv_metric = inv_metric_matrix[[1]],
115+
seed = 123)))
116+
117+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
118+
chains = 2,
119+
metric = "diag_e",
120+
inv_metric = inv_metric_vector,
121+
seed = 123)))
122+
123+
expect_silent(expect_sample_output(fit_r <- mod2$sample(data = data_list2,
124+
chains = 2,
125+
metric = "dense_e",
126+
inv_metric = inv_metric_matrix,
127+
seed = 123)))
128+
})
129+
130+
57131
test_that("sample() method works with lists of inv_metrics", {
58132
inv_metric_vector <- array(1, dim = c(1))
59133
inv_metric_vector_json <- test_path("resources", "metric", "bernoulli.inv_metric.diag_e.json")

0 commit comments

Comments
 (0)