Skip to content

Commit 5a02481

Browse files
authored
Merge pull request #394 from stan-dev/fix_inv_metric_scientific_notation
Fix inv_metric()
2 parents 075a595 + 99bb7ef commit 5a02481

File tree

4 files changed

+28
-20
lines changed

4 files changed

+28
-20
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
### Bug fixes
44

5+
* Fixed reading inverse mass matrix with values written in scientific format in
6+
the CSV. (#394)
7+
58
### New features
69

710
* Added `$sample_mpi()` for MCMC sampling with MPI. (#350)

R/read_csv.R

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,11 @@ read_csv_metadata <- function(csv_file) {
388388
inv_metric_next <- FALSE
389389
inv_metric_diagonal_next <- FALSE
390390
csv_file_info <- list()
391-
inv_metric_rows <- 0
391+
csv_file_info$inv_metric <- NULL
392+
inv_metric_rows_to_read <- -1
393+
inv_metric_rows <- -1
392394
parsing_done <- FALSE
395+
dense_inv_metric <- FALSE
393396
if (os_is_windows()) {
394397
grep_path <- repair_path(Sys.which("grep.exe"))
395398
fread_cmd <- paste0(grep_path, " '^[#a-zA-Z]' --color=never ", csv_file)
@@ -422,26 +425,28 @@ read_csv_metadata <- function(csv_file) {
422425
}
423426
} else {
424427
parse_key_val <- TRUE
425-
if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE)
426-
|| grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) {
428+
if (grepl("# Diagonal elements of inverse mass matrix:", line, perl = TRUE)) {
427429
inv_metric_next <- TRUE
428430
parse_key_val <- FALSE
431+
inv_metric_rows <- 1
432+
inv_metric_rows_to_read <- 1
433+
dense_inv_metric <- FALSE
434+
} else if (grepl("# Elements of inverse mass matrix:", line, perl = TRUE)) {
435+
inv_metric_next <- TRUE
436+
parse_key_val <- FALSE
437+
dense_inv_metric <- TRUE
429438
} else if (inv_metric_next) {
430439
inv_metric_split <- strsplit(gsub("# ", "", line), ",")
431-
if ((length(inv_metric_split) == 0) ||
432-
((length(inv_metric_split) == 1) && identical(inv_metric_split[[1]], character(0))) ||
433-
grepl("[a-zA-z]", line, perl = TRUE) ||
434-
inv_metric_split == "#") {
435-
parsing_done <- TRUE
436-
parse_key_val <- TRUE
437-
break;
440+
numeric_inv_metric_split <- rapply(inv_metric_split, as.numeric)
441+
if (inv_metric_rows == -1 && dense_inv_metric) {
442+
inv_metric_rows <- length(inv_metric_split[[1]])
443+
inv_metric_rows_to_read <- inv_metric_rows
438444
}
439-
if (inv_metric_rows == 0) {
440-
csv_file_info$inv_metric <- rapply(inv_metric_split, as.numeric)
441-
} else {
442-
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, rapply(inv_metric_split, as.numeric))
445+
csv_file_info$inv_metric <- c(csv_file_info$inv_metric, numeric_inv_metric_split)
446+
inv_metric_rows_to_read <- inv_metric_rows_to_read - 1
447+
if (inv_metric_rows_to_read == 0) {
448+
inv_metric_next <- FALSE
443449
}
444-
inv_metric_rows <- inv_metric_rows + 1
445450
parse_key_val <- FALSE
446451
}
447452
if (parse_key_val) {

tests/testthat/resources/csv/model1-1-warmup.csv

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,mu,s
140140
# Adaptation terminated
141141
# Step size = 0.712907
142142
# Diagonal elements of inverse mass matrix:
143-
# 1.00098, 0.068748
143+
# 1.00098, 0.068748e-2
144144
-19.4938,0.953779,0.712907,2,3,0,19.4971,8.11498,7.4563
145145
-19.6889,0.983261,0.712907,1,1,0,19.8364,7.96487,7.78375
146146
-18.0516,0.982462,0.712907,2,3,0,20.5179,8.24821,5.4579
@@ -241,8 +241,8 @@ lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,mu,s
241241
-13.3724,1,0.712907,1,1,0,13.6117,5.58052,2.40945
242242
-13.3724,0.292728,0.712907,2,3,0,17.7528,5.58052,2.40945
243243
-13.348,0.991998,0.712907,2,3,0,13.5989,4.34492,2.68262
244-
#
244+
#
245245
# Elapsed Time: 0.038029 seconds (Warm-up)
246246
# 0.030711 seconds (Sampling)
247247
# 0.06874 seconds (Total)
248-
#
248+
#

tests/testthat/test-csv.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ test_that("read_cmdstan_csv() returns correct diagonal of inverse mass matrix",
160160
csv_files <- c(test_path("resources", "csv", "model1-1-warmup.csv"),test_path("resources", "csv", "model1-2-warmup.csv"))
161161
csv_output <- read_cmdstan_csv(csv_files)
162162
expect_equal(as.vector(csv_output$inv_metric[[as.character(1)]]),
163-
c(1.00098, 0.068748))
163+
c(1.00098, 0.00068748))
164164
expect_equal(as.vector(csv_output$inv_metric[[as.character(2)]]),
165165
c(0.909635, 0.066384))
166166
})
@@ -296,7 +296,7 @@ test_that("read_cmdstan_csv() reads values up to adaptation", {
296296

297297
csv_out <- read_cmdstan_csv(csv_files)
298298
expect_equal(csv_out$metadata$pi, 3.14)
299-
expect_true(is.null(csv_out$metadata$pi_square))
299+
expect_false(is.null(csv_out$metadata$pi_square))
300300
})
301301

302302
test_that("remaining_columns_to_read() works", {

0 commit comments

Comments
 (0)