Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise compare_models() for Bayesian models #780

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 29 additions & 22 deletions R/compare_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#' @return A data frame with one row per model and one column per "index" (see
#' `metrics`).
#'
#' @note There is also a [`plot()`-method](https://easystats.github.io/see/articles/performance.html) implemented in the \href{https://easystats.github.io/see/}{\pkg{see}-package}.

Check warning on line 20 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=20,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 180 characters.
#'
#' @details \subsection{Model Weights}{
#' When information criteria (IC) are requested in `metrics` (i.e., any of `"all"`,
Expand Down Expand Up @@ -89,7 +89,10 @@
model_objects <- insight::ellipsis_info(..., only_models = TRUE)

# ensure proper object names
model_objects <- .check_objectnames(model_objects, sapply(match.call(expand.dots = FALSE)[["..."]], as.character))
model_objects <- .check_objectnames(
model_objects,
sapply(match.call(expand.dots = FALSE)[["..."]], as.character)
)

# drop unsupport models
supported_models <- sapply(model_objects, function(i) insight::is_model_supported(i) | inherits(i, "lavaan"))
Expand All @@ -105,11 +108,11 @@
}

# iterate over all models, i.e. model-performance for each model
m <- mapply(function(.x, .y) {

Check warning on line 111 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=111,col=8,[undesirable_function_linter] Avoid undesirable function "mapply".
dat <- model_performance(.x, metrics = metrics, estimator = estimator, verbose = FALSE)
model_name <- gsub("\"", "", insight::safe_deparse(.y), fixed = TRUE)
perf_df <- data.frame(Name = model_name, Model = class(.x)[1], dat, stringsAsFactors = FALSE)
attributes(perf_df) <- c(attributes(perf_df), attributes(dat)[!names(attributes(dat)) %in% c("names", "row.names", "class")])

Check warning on line 115 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=115,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 129 characters.
perf_df
}, model_objects, object_names, SIMPLIFY = FALSE)

Expand All @@ -119,14 +122,14 @@
})
dfs <- Reduce(function(x, y) merge(x, y, all = TRUE, sort = FALSE), m)

if (any(c("AIC", "AICc", "BIC", "WAIC") %in% names(dfs))) {
if (any(c("AIC", "AICc", "BIC", "WAIC") %in% colnames(dfs))) {
dfs$AIC_wt <- .ic_weight(dfs[["AIC"]])
dfs$AICc_wt <- .ic_weight(dfs[["AICc"]])
dfs$BIC_wt <- .ic_weight(dfs[["BIC"]])
dfs$WAIC_wt <- .ic_weight(dfs[["WAIC"]])
}

if ("LOOIC" %in% names(dfs)) {
if ("LOOIC" %in% colnames(dfs)) {
lpd_point <- do.call(cbind, lapply(attri, function(x) x$loo$pointwise[, "elpd_loo"]))
dfs$LOOIC_wt <- as.numeric(loo::stacking_weights(lpd_point))
}
Expand All @@ -144,56 +147,60 @@
}

# Reorder columns
if (all(c("BIC", "BF") %in% names(dfs))) {
idx1 <- grep("^BIC$", names(dfs))
idx2 <- grep("BF", names(dfs), fixed = TRUE)
if (all(c("BIC", "BF") %in% colnames(dfs))) {
idx1 <- grep("^BIC$", colnames(dfs))
idx2 <- grep("BF", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (all(c("AIC", "AIC_wt") %in% names(dfs))) {
idx1 <- grep("^AIC$", names(dfs))
idx2 <- grep("AIC_wt", names(dfs), fixed = TRUE)
if (all(c("AIC", "AIC_wt") %in% colnames(dfs))) {
idx1 <- grep("^AIC$", colnames(dfs))
idx2 <- grep("AIC_wt", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (all(c("BIC", "BIC_wt") %in% names(dfs))) {
idx1 <- grep("^BIC$", names(dfs))
idx2 <- grep("BIC_wt", names(dfs), fixed = TRUE)
if (all(c("BIC", "BIC_wt") %in% colnames(dfs))) {
idx1 <- grep("^BIC$", colnames(dfs))
idx2 <- grep("BIC_wt", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (all(c("AICc", "AICc_wt") %in% names(dfs))) {
idx1 <- grep("^AICc$", names(dfs))
idx2 <- grep("AICc_wt", names(dfs), fixed = TRUE)
if (all(c("AICc", "AICc_wt") %in% colnames(dfs))) {
idx1 <- grep("^AICc$", colnames(dfs))
idx2 <- grep("AICc_wt", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (all(c("WAIC", "WAIC_wt") %in% names(dfs))) {
idx1 <- grep("^WAIC$", names(dfs))
idx2 <- grep("WAIC_wt", names(dfs), fixed = TRUE)
if (all(c("WAIC", "WAIC_wt") %in% colnames(dfs))) {
idx1 <- grep("^WAIC$", colnames(dfs))
idx2 <- grep("WAIC_wt", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (all(c("LOOIC", "LOOIC_wt") %in% names(dfs))) {
idx1 <- grep("^LOOIC$", names(dfs))
idx2 <- grep("LOOIC_wt", names(dfs), fixed = TRUE)
if (all(c("LOOIC", "LOOIC_wt") %in% colnames(dfs))) {
idx1 <- grep("^LOOIC$", colnames(dfs))
idx2 <- grep("LOOIC_wt", colnames(dfs), fixed = TRUE)
last_part <- (idx1 + 1):ncol(dfs)
dfs <- dfs[, c(1:idx1, idx2, last_part[last_part != idx2])]
}
if (any(startsWith(colnames(dfs), "R2"))) {
idx <- which(startsWith(colnames(dfs), "R2"))
dfs <- dfs[, c(1, 2, idx, setdiff(seq_len(ncol(dfs)), idx))]
}

# for REML fits, warn user
if (isTRUE(verbose) &&
# only warn for REML fit
identical(estimator, "REML") &&
# only for IC comparison
any(grepl("(AIC|BIC)", names(dfs))) &&
any(grepl("(AIC|BIC)", colnames(dfs))) &&
# only when mixed models are involved, others probably don't have problems with REML fit
any(sapply(model_objects, insight::is_mixed_model)) &&
# only if not all models have same fixed effects (else, REML is ok)
!isTRUE(attributes(model_objects)$same_fixef)) {
insight::format_alert(
"Information criteria (like AIC) are based on REML fits (i.e. `estimator=\"REML\"`).",
"Please note that information criteria are probably not directly comparable and that it is not recommended comparing models with different fixed effects in such cases."

Check warning on line 203 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=203,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 174 characters.
)
}

Expand All @@ -212,7 +219,7 @@
formatted_table <- format(x = x, digits = digits, format = "text", ...)

if ("Performance_Score" %in% colnames(formatted_table)) {
footer <- c(sprintf("\nModel `%s` (of class `%s`) performed best with an overall performance score of %s.", formatted_table$Model[1], formatted_table$Type[1], formatted_table$Performance_Score[1]), "yellow")

Check warning on line 222 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=222,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 211 characters.
} else {
footer <- NULL
}
Expand All @@ -224,7 +231,7 @@
colnames(formatted_table)[1] <- "Metric"
}

cat(insight::export_table(x = formatted_table, digits = digits, format = "text", caption = table_caption, footer = footer, ...))

Check warning on line 234 in R/compare_performance.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/compare_performance.R,line=234,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 130 characters.
invisible(x)
}

Expand Down
Loading