@@ -4,6 +4,9 @@ set_cmdstan_path()
4
4
mod <- testing_model(" bernoulli" )
5
5
data_list <- testing_data(" bernoulli" )
6
6
7
+ mod2 <- testing_model(" logistic" )
8
+ data_list2 <- testing_data(" logistic" )
9
+
7
10
8
11
test_that(" sample() method works with provided inv_metrics" , {
9
12
inv_metric_vector <- array (1 , dim = c(1 ))
@@ -54,7 +57,7 @@ test_that("sample() method works with provided inv_metrics", {
54
57
})
55
58
56
59
57
- test_that(" sample() method works with inv_metrics extracted from previous fit" , {
60
+ test_that(" sample() method works with inv_metrics extracted from previous fit with 1 parameter " , {
58
61
expect_sample_output(fit_r <- mod $ sample(data = data_list ,
59
62
chains = 2 ,
60
63
seed = 123 ))
@@ -89,6 +92,41 @@ test_that("sample() method works with inv_metrics extracted from previous fit",
89
92
seed = 123 )))
90
93
})
91
94
95
+ test_that(" sample() method works with inv_metrics extracted from previous fit with > 1 parameter" , {
96
+ expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
97
+ chains = 2 ,
98
+ seed = 123 ))
99
+ inv_metric_vector <- fit_r $ inv_metric(matrix = FALSE )
100
+ inv_metric_matrix <- fit_r $ inv_metric()
101
+
102
+ expect_equal(length(inv_metric_vector [[1 ]]), 4 )
103
+ expect_equal(dim(inv_metric_matrix [[1 ]]), c(4 , 4 ))
104
+
105
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
106
+ chains = 1 ,
107
+ metric = " diag_e" ,
108
+ inv_metric = inv_metric_vector [[1 ]],
109
+ seed = 123 )))
110
+
111
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
112
+ chains = 1 ,
113
+ metric = " dense_e" ,
114
+ inv_metric = inv_metric_matrix [[1 ]],
115
+ seed = 123 )))
116
+
117
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
118
+ chains = 2 ,
119
+ metric = " diag_e" ,
120
+ inv_metric = inv_metric_vector ,
121
+ seed = 123 )))
122
+
123
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
124
+ chains = 2 ,
125
+ metric = " dense_e" ,
126
+ inv_metric = inv_metric_matrix ,
127
+ seed = 123 )))
128
+ })
129
+
92
130
93
131
test_that(" sample() method works with lists of inv_metrics" , {
94
132
inv_metric_vector <- array (1 , dim = c(1 ))
0 commit comments