@@ -4,6 +4,9 @@ set_cmdstan_path()
4
4
mod <- testing_model(" bernoulli" )
5
5
data_list <- testing_data(" bernoulli" )
6
6
7
+ mod2 <- testing_model(" logistic" )
8
+ data_list2 <- testing_data(" logistic" )
9
+
7
10
8
11
test_that(" sample() method works with provided inv_metrics" , {
9
12
inv_metric_vector <- array (1 , dim = c(1 ))
@@ -54,6 +57,77 @@ test_that("sample() method works with provided inv_metrics", {
54
57
})
55
58
56
59
60
+ test_that(" sample() method works with inv_metrics extracted from previous fit with 1 parameter" , {
61
+ expect_sample_output(fit_r <- mod $ sample(data = data_list ,
62
+ chains = 2 ,
63
+ seed = 123 ))
64
+ inv_metric_vector <- fit_r $ inv_metric(matrix = FALSE )
65
+ inv_metric_matrix <- fit_r $ inv_metric()
66
+
67
+ expect_equal(dim(inv_metric_vector [[1 ]]), 1 )
68
+ expect_equal(dim(inv_metric_matrix [[1 ]]), c(1 , 1 ))
69
+
70
+ expect_silent(expect_sample_output(fit_r <- mod $ sample(data = data_list ,
71
+ chains = 1 ,
72
+ metric = " diag_e" ,
73
+ inv_metric = inv_metric_vector [[1 ]],
74
+ seed = 123 )))
75
+
76
+ expect_silent(expect_sample_output(fit_r <- mod $ sample(data = data_list ,
77
+ chains = 1 ,
78
+ metric = " dense_e" ,
79
+ inv_metric = inv_metric_matrix [[1 ]],
80
+ seed = 123 )))
81
+
82
+ expect_silent(expect_sample_output(fit_r <- mod $ sample(data = data_list ,
83
+ chains = 2 ,
84
+ metric = " diag_e" ,
85
+ inv_metric = inv_metric_vector ,
86
+ seed = 123 )))
87
+
88
+ expect_silent(expect_sample_output(fit_r <- mod $ sample(data = data_list ,
89
+ chains = 2 ,
90
+ metric = " dense_e" ,
91
+ inv_metric = inv_metric_matrix ,
92
+ seed = 123 )))
93
+ })
94
+
95
+ test_that(" sample() method works with inv_metrics extracted from previous fit with > 1 parameter" , {
96
+ expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
97
+ chains = 2 ,
98
+ seed = 123 ))
99
+ inv_metric_vector <- fit_r $ inv_metric(matrix = FALSE )
100
+ inv_metric_matrix <- fit_r $ inv_metric()
101
+
102
+ expect_equal(length(inv_metric_vector [[1 ]]), 4 )
103
+ expect_equal(dim(inv_metric_matrix [[1 ]]), c(4 , 4 ))
104
+
105
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
106
+ chains = 1 ,
107
+ metric = " diag_e" ,
108
+ inv_metric = inv_metric_vector [[1 ]],
109
+ seed = 123 )))
110
+
111
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
112
+ chains = 1 ,
113
+ metric = " dense_e" ,
114
+ inv_metric = inv_metric_matrix [[1 ]],
115
+ seed = 123 )))
116
+
117
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
118
+ chains = 2 ,
119
+ metric = " diag_e" ,
120
+ inv_metric = inv_metric_vector ,
121
+ seed = 123 )))
122
+
123
+ expect_silent(expect_sample_output(fit_r <- mod2 $ sample(data = data_list2 ,
124
+ chains = 2 ,
125
+ metric = " dense_e" ,
126
+ inv_metric = inv_metric_matrix ,
127
+ seed = 123 )))
128
+ })
129
+
130
+
57
131
test_that(" sample() method works with lists of inv_metrics" , {
58
132
inv_metric_vector <- array (1 , dim = c(1 ))
59
133
inv_metric_vector_json <- test_path(" resources" , " metric" , " bernoulli.inv_metric.diag_e.json" )
0 commit comments