@@ -25,48 +25,32 @@ test_that("process_fitted_params() works with basic input types", {
25
25
})
26
26
27
27
test_that(" process_fitted_params() errors with bad args" , {
28
+ error_msg <- " 'fitted_params' must be a list of paths to CSV files, a CmdStanMCMC/CmdStanVB object, a posterior::draws_array or a posterior::draws_matrix."
28
29
expect_error(
29
30
process_fitted_params(5 ),
30
- " 'fitted_params' should be a vector of paths or a CmdStanMCMC object. "
31
+ error_msg
31
32
)
32
33
expect_error(
33
34
process_fitted_params(NULL ),
34
- " 'fitted_params' should be a vector of paths or a CmdStanMCMC object."
35
- )
36
- expect_error(
37
- process_fitted_params(fit_vb ),
38
- " 'fitted_params' should be a vector of paths or a CmdStanMCMC object."
35
+ error_msg
39
36
)
40
37
expect_error(
41
38
process_fitted_params(fit_optimize ),
42
- " 'fitted_params' should be a vector of paths or a CmdStanMCMC object."
43
- )
44
-
45
- fit_tmp <- testing_fit(" bernoulli" , method = " sample" , seed = 123 )
46
- temp_file <- tempfile(fileext = " .rds" )
47
- saveRDS(fit_tmp , file = temp_file )
48
- rm(fit_tmp )
49
- gc()
50
- fit_tmp_null <- readRDS(temp_file )
51
- expect_error(
52
- process_fitted_params(fit_tmp_null ),
53
- " Unable to obtain draws from the fit \\ (CmdStanMCMC\\ ) object."
39
+ error_msg
54
40
)
55
41
56
42
fit_tmp <- testing_fit(" bernoulli" , method = " sample" , seed = 123 )
57
- fit_tmp $ draws()
58
43
temp_file <- tempfile(fileext = " .rds" )
59
44
saveRDS(fit_tmp , file = temp_file )
60
45
rm(fit_tmp )
61
46
gc()
62
47
fit_tmp_null <- readRDS(temp_file )
63
48
expect_error(
64
49
process_fitted_params(fit_tmp_null ),
65
- " Unable to obtain sampler diagnostics from the fit \\ (CmdStanMCMC \\ ) object."
50
+ " Unable to obtain draws from the fit object."
66
51
)
67
52
})
68
53
69
-
70
54
test_that(" process_fitted_params() works if output_files in fit do not exist" , {
71
55
fit_ref <- testing_fit(" bernoulli" , method = " sample" , seed = 123 )
72
56
fit_tmp <- testing_fit(" bernoulli" , method = " sample" , seed = 123 )
@@ -108,4 +92,118 @@ test_that("process_fitted_params() works if output_files in fit do not exist", {
108
92
}
109
93
})
110
94
95
+ test_that(" process_fitted_params() works with CmdStanMCMC" , {
96
+ fit <- testing_fit(" logistic" , method = " sample" , seed = 123 )
97
+ fit_params_files <- process_fitted_params(fit )
98
+ expect_true(all(file.exists(fit_params_files )))
99
+ chain <- 1
100
+ for (file in fit_params_files ) {
101
+ if (os_is_windows()) {
102
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
103
+ fread_cmd <- paste0(grep_path , " -v '^#' --color=never " , file )
104
+ } else {
105
+ fread_cmd <- paste0(" grep -v '^#' --color=never " , file )
106
+ }
107
+ suppressWarnings(
108
+ fit_params_tmp <- data.table :: fread(
109
+ cmd = fread_cmd
110
+ )
111
+ )
112
+ fit_params_tmp <- posterior :: as_draws_array(fit_params_tmp )
113
+ posterior :: variables(fit_params_tmp ) <- repair_variable_names(posterior :: variables(fit_params_tmp ))
114
+ expect_equal(
115
+ posterior :: subset_draws(fit $ draws(), variable = " lp__" , chain = chain ),
116
+ posterior :: subset_draws(fit_params_tmp , variable = " lp__" )
117
+ )
118
+ expect_equal(
119
+ posterior :: subset_draws(fit $ draws(), variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ), chain = chain ),
120
+ posterior :: subset_draws(fit_params_tmp , variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ),)
121
+ )
122
+ chain <- chain + 1
123
+ }
124
+ })
111
125
126
+ test_that(" process_fitted_params() works with draws_array" , {
127
+ fit <- testing_fit(" logistic" , method = " sample" , seed = 123 )
128
+ fit_params_files <- process_fitted_params(fit $ draws())
129
+ expect_true(all(file.exists(fit_params_files )))
130
+ chain <- 1
131
+ for (file in fit_params_files ) {
132
+ if (os_is_windows()) {
133
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
134
+ fread_cmd <- paste0(grep_path , " -v '^#' --color=never " , file )
135
+ } else {
136
+ fread_cmd <- paste0(" grep -v '^#' --color=never " , file )
137
+ }
138
+ suppressWarnings(
139
+ fit_params_tmp <- data.table :: fread(
140
+ cmd = fread_cmd
141
+ )
142
+ )
143
+ fit_params_tmp <- posterior :: as_draws_array(fit_params_tmp )
144
+ posterior :: variables(fit_params_tmp ) <- repair_variable_names(posterior :: variables(fit_params_tmp ))
145
+ expect_equal(
146
+ posterior :: subset_draws(fit $ draws(), variable = " lp__" , chain = chain ),
147
+ posterior :: subset_draws(fit_params_tmp , variable = " lp__" )
148
+ )
149
+ expect_equal(
150
+ posterior :: subset_draws(fit $ draws(), variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ), chain = chain ),
151
+ posterior :: subset_draws(fit_params_tmp , variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ),)
152
+ )
153
+ chain <- chain + 1
154
+ }
155
+ })
156
+
157
+ test_that(" process_fitted_params() works with CmdStanVB" , {
158
+ fit <- testing_fit(" logistic" , method = " variational" , seed = 123 )
159
+ file <- process_fitted_params(fit )
160
+ expect_true(file.exists(file ))
161
+ if (os_is_windows()) {
162
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
163
+ fread_cmd <- paste0(grep_path , " -v '^#' --color=never " , file )
164
+ } else {
165
+ fread_cmd <- paste0(" grep -v '^#' --color=never " , file )
166
+ }
167
+ suppressWarnings(
168
+ fit_params_tmp <- data.table :: fread(
169
+ cmd = fread_cmd
170
+ )
171
+ )
172
+ fit_params_tmp <- posterior :: as_draws_array(fit_params_tmp )
173
+ posterior :: variables(fit_params_tmp ) <- repair_variable_names(posterior :: variables(fit_params_tmp ))
174
+ expect_equal(
175
+ posterior :: subset_draws(posterior :: as_draws_array(fit $ draws()), variable = " lp__" ),
176
+ posterior :: subset_draws(fit_params_tmp , variable = " lp__" )
177
+ )
178
+ expect_equal(
179
+ posterior :: subset_draws(posterior :: as_draws_array(fit $ draws()), variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" )),
180
+ posterior :: subset_draws(fit_params_tmp , variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ))
181
+ )
182
+ })
183
+
184
+ test_that(" process_fitted_params() works with draws_matrix" , {
185
+ fit <- testing_fit(" logistic" , method = " variational" , seed = 123 )
186
+ file <- process_fitted_params(fit $ draws())
187
+ expect_true(file.exists(file ))
188
+ if (os_is_windows()) {
189
+ grep_path <- repair_path(Sys.which(" grep.exe" ))
190
+ fread_cmd <- paste0(grep_path , " -v '^#' --color=never " , file )
191
+ } else {
192
+ fread_cmd <- paste0(" grep -v '^#' --color=never " , file )
193
+ }
194
+ suppressWarnings(
195
+ fit_params_tmp <- data.table :: fread(
196
+ cmd = fread_cmd
197
+ )
198
+ )
199
+ fit_params_tmp <- posterior :: as_draws_array(fit_params_tmp )
200
+ posterior :: variables(fit_params_tmp ) <- repair_variable_names(posterior :: variables(fit_params_tmp ))
201
+ expect_equal(
202
+ posterior :: subset_draws(posterior :: as_draws_array(fit $ draws()), variable = " lp__" ),
203
+ posterior :: subset_draws(fit_params_tmp , variable = " lp__" )
204
+ )
205
+ expect_equal(
206
+ posterior :: subset_draws(posterior :: as_draws_array(fit $ draws()), variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" )),
207
+ posterior :: subset_draws(fit_params_tmp , variable = c(" alpha" , " beta[1]" , " beta[2]" , " beta[3]" ))
208
+ )
209
+ })
0 commit comments