diff --git a/NEWS.md b/NEWS.md index 4d9a3af0..71133fe0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,8 @@ * `loo_compare()` output now includes additional columns: `p_worse`, `diag_diff`, and `diag_elpd`, providing richer diagnostics for model comparison by @florence-bockting in #300 +* `print.compare.loo()` regains a `simplify = FALSE` mode for showing the full + comparison table, including the available estimate and standard-error columns # loo 2.9.0 diff --git a/R/loo_compare.R b/R/loo_compare.R index ce5bea4a..80202507 100644 --- a/R/loo_compare.R +++ b/R/loo_compare.R @@ -92,6 +92,9 @@ #' comp <- loo_compare(loo1, loo2, loo3) #' print(comp, digits = 2) #' +#' # print full table with pointwise ELPD and LOOIC +#' print(comp, simplify = FALSE) +#' #' # can use a list of objects with custom names #' # the names will be used in the output #' loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3)) @@ -170,7 +173,11 @@ loo_compare.default <- function(x, ...) { #' @param p_worse For the print method only, should we include the normal #' approximation based probability of each model having worse performance than #' the best model? The default is `TRUE`. -print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE) { +#' @param simplify For the print method only, should the output be simplified +#' to only include the model names and ELPD differences? The default is +#' `TRUE`. If `FALSE`, the full comparison table is printed including +#' pointwise ELPD, LOOIC/WAIC, and their standard errors for each model. +print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE, simplify = TRUE) { if (inherits(x, "old_compare.loo")) { return(unclass(x)) } @@ -193,6 +200,16 @@ print.compare.loo <- function(x, ..., digits = 1, p_worse = TRUE) { diag_elpd = x[, "diag_elpd"] ) } + if (!simplify) { + est_cols <- c("elpd_loo", "se_elpd_loo", "p_loo", "se_p_loo", + "looic", "se_looic", + "elpd_waic", "se_elpd_waic", "p_waic", "se_p_waic", + "waic", "se_waic") + avail <- intersect(est_cols, colnames(x)) + if (length(avail)) { + x2 <- cbind(x2, .fr(x[, avail], digits)) + } + } print(x2, quote = FALSE, row.names = FALSE) # show glossary for diagnostic flags diff --git a/man/loo_compare.Rd b/man/loo_compare.Rd index c54b9f86..361b52c2 100644 --- a/man/loo_compare.Rd +++ b/man/loo_compare.Rd @@ -12,7 +12,7 @@ loo_compare(x, ...) \method{loo_compare}{default}(x, ...) -\method{print}{compare.loo}(x, ..., digits = 1, p_worse = TRUE) +\method{print}{compare.loo}(x, ..., digits = 1, p_worse = TRUE, simplify = TRUE) \method{print}{compare.loo_ss}(x, ..., digits = 1) } @@ -30,6 +30,11 @@ printing.} \item{p_worse}{For the print method only, should we include the normal approximation based probability of each model having worse performance than the best model? The default is \code{TRUE}.} + +\item{simplify}{For the print method only, should the output be simplified +to only include the model names and ELPD differences? The default is +\code{TRUE}. If \code{FALSE}, the full comparison table is printed including +pointwise ELPD, LOOIC/WAIC, and their standard errors for each model.} } \value{ A data frame with class \code{"compare.loo"} that has its own @@ -119,6 +124,9 @@ loo3 <- loo(LL + 2) # should be best model when compared comp <- loo_compare(loo1, loo2, loo3) print(comp, digits = 2) +# print full table with pointwise ELPD and LOOIC +print(comp, simplify = FALSE) + # can use a list of objects with custom names # the names will be used in the output loo_compare(list("apple" = loo1, "banana" = loo2, "cherry" = loo3)) diff --git a/tests/testthat/test_compare.R b/tests/testthat/test_compare.R index 4a00a10a..998dac3f 100644 --- a/tests/testthat/test_compare.R +++ b/tests/testthat/test_compare.R @@ -109,6 +109,13 @@ test_that("loo_compare returns expected results (2 models)", { expect_snapshot_value(comp2, style = "serialize") expect_snapshot(print(comp2)) expect_snapshot(print(comp2, p_worse = FALSE)) + out_full <- paste( + capture.output(suppressMessages(print(comp2, simplify = FALSE))), + collapse = "\n" + ) + expect_match(out_full, "elpd_waic\\s+se_elpd_waic") + expect_match(out_full, "p_waic\\s+se_p_waic\\s+waic\\s+se_waic") + expect_message(print(comp2, simplify = FALSE), "Diagnostic flags present.", fixed = TRUE) # specifying objects via ... and via arg x gives equal results expect_equal(comp2, loo_compare(x = list(w1, w2)))