diff --git a/R/cal-estimate-beta.R b/R/cal-estimate-beta.R index bc580878..287ad56c 100644 --- a/R/cal-estimate-beta.R +++ b/R/cal-estimate-beta.R @@ -34,8 +34,13 @@ cal_estimate_beta.data.frame <- function(.data, location_params = 1, estimate = dplyr::starts_with(".pred_"), parameters = NULL, - ...) { + ..., + group = NULL) { stop_null_parameters(parameters) + + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + cal_beta_impl( .data = .data, truth = {{ truth }}, @@ -60,7 +65,6 @@ cal_estimate_beta.tune_results <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = "first", parameters = parameters, ... @@ -227,7 +231,6 @@ cal_beta_impl_single <- function(.data, beta_model } - check_cal_groups <- function(group, .data, call = rlang::env_parent()) { group <- enquo(group) if (!any(names(.data) == ".config")) { diff --git a/R/cal-estimate-isotonic.R b/R/cal-estimate-isotonic.R index 6207131c..9dad6d0e 100644 --- a/R/cal-estimate-isotonic.R +++ b/R/cal-estimate-isotonic.R @@ -42,8 +42,13 @@ cal_estimate_isotonic.data.frame <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), parameters = NULL, - ...) { + ..., + group = NULL) { stop_null_parameters(parameters) + + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + cal_isoreg_impl( .data = .data, truth = {{ truth }}, @@ -64,7 +69,6 @@ cal_estimate_isotonic.tune_results <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = "first", parameters = parameters, ... @@ -117,8 +121,13 @@ cal_estimate_isotonic_boot.data.frame <- function(.data, estimate = dplyr::starts_with(".pred"), times = 10, parameters = NULL, - ...) 
{ + ..., + group = NULL) { stop_null_parameters(parameters) + + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + cal_isoreg_impl( .data = .data, truth = {{ truth }}, @@ -141,7 +150,6 @@ cal_estimate_isotonic_boot.tune_results <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = "first", # or null for regression parameters = parameters, ... diff --git a/R/cal-estimate-linear.R b/R/cal-estimate-linear.R index 674f0b29..74ef3e73 100644 --- a/R/cal-estimate-linear.R +++ b/R/cal-estimate-linear.R @@ -1,5 +1,6 @@ #------------------------------- Methods --------------------------------------- #' Uses a linear regression model to calibrate numeric predictions +#' @inheritParams cal_estimate_logistic #' @param .data A `data.frame` object, or `tune_results` object, that contains #' predictions and probability columns. #' @param truth The column identifier for the observed outcome data (that is @@ -67,7 +68,8 @@ cal_estimate_linear <- function(.data, estimate = dplyr::matches("^.pred$"), smooth = TRUE, parameters = NULL, - ...) { + ..., + group = NULL) { UseMethod("cal_estimate_linear") } @@ -78,8 +80,13 @@ cal_estimate_linear.data.frame <- function(.data, estimate = dplyr::matches("^.pred$"), smooth = TRUE, parameters = NULL, - ...) { + ..., + group = NULL) { stop_null_parameters(parameters) + + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + cal_linear_impl( .data = .data, truth = {{ truth }}, @@ -102,7 +109,6 @@ cal_estimate_linear.tune_results <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = NA_character_, parameters = parameters, ... 
diff --git a/R/cal-estimate-logistic.R b/R/cal-estimate-logistic.R index 54a343a3..9e926741 100644 --- a/R/cal-estimate-logistic.R +++ b/R/cal-estimate-logistic.R @@ -12,6 +12,9 @@ #' @param parameters (Optional) An optional tibble of tuning parameter values #' that can be used to filter the predicted values before processing. Applies #' only to `tune_results` objects. +#' @param group The column identifier for the grouping variable. This should be +#' a single unquoted column name. Defaults to `NULL`. When `group = NULL` no +#' grouping will take place. #' @param ... Additional arguments passed to the models or routines used to #' calculate the new probabilities. #' @param smooth Applies to the logistic models. It switches between logistic @@ -54,8 +57,13 @@ cal_estimate_logistic.data.frame <- function(.data, estimate = dplyr::starts_with(".pred_"), smooth = TRUE, parameters = NULL, - ...) { + ..., + group = NULL) { stop_null_parameters(parameters) + + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + cal_logistic_impl( .data = .data, truth = {{ truth }}, @@ -78,7 +86,6 @@ cal_estimate_logistic.tune_results <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = "first", parameters = parameters, ... diff --git a/R/cal-estimate-multinom.R b/R/cal-estimate-multinom.R index 98421561..da1a3aaa 100644 --- a/R/cal-estimate-multinom.R +++ b/R/cal-estimate-multinom.R @@ -60,9 +60,13 @@ cal_estimate_multinomial.data.frame <- estimate = dplyr::starts_with(".pred_"), smooth = TRUE, parameters = NULL, - ...) 
{ + ..., + group = NULL) { stop_null_parameters(parameters) + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) + truth <- enquo(truth) cal_multinom_impl( .data = .data, @@ -87,7 +91,6 @@ cal_estimate_multinomial.tune_results <- .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = NULL, event_level = "first", parameters = parameters, ... diff --git a/R/cal-plot-breaks.R b/R/cal-plot-breaks.R index cd641aee..b157c688 100644 --- a/R/cal-plot-breaks.R +++ b/R/cal-plot-breaks.R @@ -88,7 +88,6 @@ cal_plot_breaks <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.90, include_ribbon = TRUE, @@ -104,16 +103,16 @@ cal_plot_breaks <- function(.data, cal_plot_breaks.data.frame <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.90, include_ribbon = TRUE, include_rug = TRUE, include_points = TRUE, event_level = c("auto", "first", "second"), - ...) { - - check_cal_groups({{ group }}, .data) + ..., + group = NULL) { + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) cal_plot_breaks_impl( .data = .data, @@ -134,7 +133,6 @@ cal_plot_breaks.data.frame <- function(.data, cal_plot_breaks.tune_results <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.90, include_ribbon = TRUE, @@ -142,15 +140,19 @@ cal_plot_breaks.tune_results <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) { - if (rlang::quo_is_null(enquo(group))) { - group <- expr(.config) - } - - cal_plot_breaks_impl( + tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, + event_level = event_level, + ... 
+ ) + + cal_plot_breaks_impl( + .data = tune_args$predictions, + truth = !!tune_args$truth, + estimate = !!tune_args$estimate, + group = !!tune_args$group, num_breaks = num_breaks, conf_level = conf_level, include_ribbon = include_ribbon, @@ -293,7 +295,6 @@ cal_plot_breaks_impl <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, event_level = event_level, ... ) diff --git a/R/cal-plot-logistic.R b/R/cal-plot-logistic.R index 0e188656..f3e91276 100644 --- a/R/cal-plot-logistic.R +++ b/R/cal-plot-logistic.R @@ -39,7 +39,6 @@ cal_plot_logistic <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.90, smooth = TRUE, include_rug = TRUE, @@ -54,15 +53,15 @@ cal_plot_logistic <- function(.data, cal_plot_logistic.data.frame <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.90, smooth = TRUE, include_rug = TRUE, include_ribbon = TRUE, event_level = c("auto", "first", "second"), - ...) { - - check_cal_groups({{ group }}, .data) + ..., + group = NULL) { + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) cal_plot_logistic_impl( .data = .data, @@ -82,22 +81,25 @@ cal_plot_logistic.data.frame <- function(.data, cal_plot_logistic.tune_results <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.90, smooth = TRUE, include_rug = TRUE, include_ribbon = TRUE, event_level = c("auto", "first", "second"), ...) { - if (rlang::quo_is_null(enquo(group))) { - group <- expr(.config) - } - - cal_plot_logistic_impl( + tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, + event_level = event_level, + ... 
+ ) + + cal_plot_logistic_impl( + .data = tune_args$predictions, + truth = !!tune_args$truth, + estimate = !!tune_args$estimate, + group = !!tune_args$group, conf_level = conf_level, include_ribbon = include_ribbon, include_rug = include_rug, @@ -201,7 +203,6 @@ cal_plot_logistic_impl <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, event_level = event_level, ... ) diff --git a/R/cal-plot-regression.R b/R/cal-plot-regression.R index e102ffa6..04973b93 100644 --- a/R/cal-plot-regression.R +++ b/R/cal-plot-regression.R @@ -28,7 +28,6 @@ cal_plot_regression <- function(.data, truth = NULL, estimate = NULL, - group = NULL, smooth = TRUE, ...) { UseMethod("cal_plot_regression") @@ -37,11 +36,10 @@ cal_plot_regression <- function(.data, cal_plot_regression_impl <- function(.data, truth = NULL, estimate = NULL, - group = NULL, smooth = TRUE, - ...) { - - check_cal_groups({{ group }}, .data) + ..., + group = NULL) { + group <- get_group_argument({{ group }}, .data) truth <- enquo(truth) estimate <- enquo(estimate) @@ -68,14 +66,12 @@ cal_plot_regression.data.frame <- cal_plot_regression_impl cal_plot_regression.tune_results <- function(.data, truth = NULL, estimate = NULL, - group = NULL, smooth = TRUE, ...) { tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, ... ) @@ -95,6 +91,10 @@ regression_plot_impl <- function(.data, truth, estimate, group, estimate <- enquo(estimate) group <- enquo(group) + if (quo_is_null(group)) { + .data[[".config"]] <- NULL + } + gp_vars <- dplyr::group_vars(.data) if (length(gp_vars)) { diff --git a/R/cal-plot-utils.R b/R/cal-plot-utils.R index 51e06ec3..3fb00145 100644 --- a/R/cal-plot-utils.R +++ b/R/cal-plot-utils.R @@ -217,7 +217,6 @@ process_level <- function(x) { tune_results_args <- function(.data, truth, estimate, - group, event_level, parameters = NULL, ...) 
{ @@ -239,7 +238,6 @@ tune_results_args <- function(.data, truth <- enquo(truth) estimate <- enquo(estimate) - group <- enquo(group) if (quo_is_null(truth)) { truth_str <- attributes(.data)$outcome @@ -250,15 +248,17 @@ tune_results_args <- function(.data, estimate <- expr(dplyr::starts_with(".pred")) } - if (quo_is_null(group)) { + if (dplyr::n_distinct(.data[[".predictions"]][[1]][[".config"]]) > 1) { group <- quo(.config) + } else { + group <- quo(NULL) } list( truth = quo(!!truth), estimate = quo(!!estimate), estimate = estimate, - group = quo(!!group), + group = group, predictions = predictions ) } diff --git a/R/cal-plot-windowed.R b/R/cal-plot-windowed.R index 26dc1bec..8593ec4e 100644 --- a/R/cal-plot-windowed.R +++ b/R/cal-plot-windowed.R @@ -42,7 +42,6 @@ cal_plot_windowed <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size / 2, conf_level = 0.90, @@ -59,7 +58,6 @@ cal_plot_windowed <- function(.data, cal_plot_windowed.data.frame <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size / 2, conf_level = 0.90, @@ -67,8 +65,10 @@ cal_plot_windowed.data.frame <- function(.data, include_rug = TRUE, include_points = TRUE, event_level = c("auto", "first", "second"), - ...) { - check_cal_groups({{ group }}, .data) + ..., + group = NULL) { + group <- get_group_argument({{ group }}, .data) + .data <- dplyr::group_by(.data, dplyr::across({{ group }})) cal_plot_windowed_impl( .data = .data, @@ -91,7 +91,6 @@ cal_plot_windowed.data.frame <- function(.data, cal_plot_windowed.tune_results <- function(.data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size / 2, conf_level = 0.90, @@ -100,15 +99,19 @@ cal_plot_windowed.tune_results <- function(.data, include_points = TRUE, event_level = c("auto", "first", "second"), ...) 
{ - if (rlang::quo_is_null(enquo(group))) { - group <- expr(.config) - } - - cal_plot_windowed_impl( + tune_args <- tune_results_args( .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, + event_level = event_level, + ... + ) + + cal_plot_windowed_impl( + .data = tune_args$predictions, + truth = !!tune_args$truth, + estimate = !!tune_args$estimate, + group = !!tune_args$group, window_size = window_size, step_size = step_size, conf_level = conf_level, @@ -221,7 +224,6 @@ cal_plot_windowed_impl <- function(.data, .data = .data, truth = {{ truth }}, estimate = {{ estimate }}, - group = {{ group }}, event_level = event_level, ... ) diff --git a/R/utils.R b/R/utils.R index 56a53f35..452e4662 100644 --- a/R/utils.R +++ b/R/utils.R @@ -44,3 +44,38 @@ is_ordered.class_pred <- function(x) { is_ordered.default <- function(x) { is.ordered(x) } + +get_group_argument <- function(group, .data, call = rlang::env_parent()) { + group <- rlang::enquo(group) + + group_names <- tidyselect::eval_select( + expr = group, + data = .data, + allow_rename = FALSE, + allow_empty = TRUE, + allow_predicates = TRUE, + error_call = call + ) + + n_group_names <- length(group_names) + + useable_config <- n_group_names == 0 && + ".config" %in% names(.data) && + dplyr::n_distinct(.data[[".config"]]) > 1 + + if (useable_config) { + return(quo(.config)) + } + + if (n_group_names > 1) { + cli::cli_abort( + c( + x = "{.arg group} cannot select more than one column.", + i = "The following {n_group_names} columns were selected:", + i = "{names(group_names)}" + ) + ) + } + + return(group) +} diff --git a/man/cal_estimate_beta.Rd b/man/cal_estimate_beta.Rd index 3b8bc8ed..0fc6e81e 100644 --- a/man/cal_estimate_beta.Rd +++ b/man/cal_estimate_beta.Rd @@ -23,7 +23,8 @@ cal_estimate_beta( location_params = 1, estimate = dplyr::starts_with(".pred_"), parameters = NULL, - ... 
+ ..., + group = NULL ) \method{cal_estimate_beta}{tune_results}( @@ -61,6 +62,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new probabilities.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. Default to \code{NULL}. When \code{group = NULL} no +grouping will take place.} } \description{ Uses a Beta calibration model to calculate new probabilities diff --git a/man/cal_estimate_isotonic.Rd b/man/cal_estimate_isotonic.Rd index c6afb0eb..6b824dda 100644 --- a/man/cal_estimate_isotonic.Rd +++ b/man/cal_estimate_isotonic.Rd @@ -19,7 +19,8 @@ cal_estimate_isotonic( truth = NULL, estimate = dplyr::starts_with(".pred"), parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_isotonic}{tune_results}( @@ -49,6 +50,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new probabilities.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. Default to \code{NULL}. When \code{group = NULL} no +grouping will take place.} } \description{ Uses an Isotonic regression model to calibrate model predictions. diff --git a/man/cal_estimate_isotonic_boot.Rd b/man/cal_estimate_isotonic_boot.Rd index d10edc46..bd8b8957 100644 --- a/man/cal_estimate_isotonic_boot.Rd +++ b/man/cal_estimate_isotonic_boot.Rd @@ -21,7 +21,8 @@ cal_estimate_isotonic_boot( estimate = dplyr::starts_with(".pred"), times = 10, parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_isotonic_boot}{tune_results}( @@ -54,6 +55,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new probabilities.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. 
Default to \code{NULL}. When \code{group = NULL} no +grouping will take place.} } \description{ Uses a bootstrapped Isotonic regression model to calibrate probabilities diff --git a/man/cal_estimate_linear.Rd b/man/cal_estimate_linear.Rd index 43059557..09d60712 100644 --- a/man/cal_estimate_linear.Rd +++ b/man/cal_estimate_linear.Rd @@ -12,7 +12,8 @@ cal_estimate_linear( estimate = dplyr::matches("^.pred$"), smooth = TRUE, parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_linear}{data.frame}( @@ -21,7 +22,8 @@ cal_estimate_linear( estimate = dplyr::matches("^.pred$"), smooth = TRUE, parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_linear}{tune_results}( @@ -52,6 +54,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new predictions.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. Default to \code{NULL}. When \code{group = NULL} no +grouping will take place.} } \description{ Uses a linear regression model to calibrate numeric predictions diff --git a/man/cal_estimate_logistic.Rd b/man/cal_estimate_logistic.Rd index 5ca44525..30506749 100644 --- a/man/cal_estimate_logistic.Rd +++ b/man/cal_estimate_logistic.Rd @@ -21,7 +21,8 @@ cal_estimate_logistic( estimate = dplyr::starts_with(".pred_"), smooth = TRUE, parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_logistic}{tune_results}( @@ -55,6 +56,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new probabilities.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. Default to \code{NULL}. 
When \code{group = NULL} no +grouping will take place.} } \description{ Uses a logistic regression model to calibrate probabilities diff --git a/man/cal_estimate_multinomial.Rd b/man/cal_estimate_multinomial.Rd index 481e8cd3..1a67fa7f 100644 --- a/man/cal_estimate_multinomial.Rd +++ b/man/cal_estimate_multinomial.Rd @@ -21,7 +21,8 @@ cal_estimate_multinomial( estimate = dplyr::starts_with(".pred_"), smooth = TRUE, parameters = NULL, - ... + ..., + group = NULL ) \method{cal_estimate_multinomial}{tune_results}( @@ -55,6 +56,10 @@ only to \code{tune_results} objects.} \item{...}{Additional arguments passed to the models or routines used to calculate the new probabilities.} + +\item{group}{The column identifier for the grouping variable. This should be +one or more unquoted column name. Default to \code{NULL}. When \code{group = NULL} no +grouping will take place.} } \description{ Uses a Multinomial calibration model to calculate new probabilities diff --git a/man/cal_plot_breaks.Rd b/man/cal_plot_breaks.Rd index d687fd0a..cb57fd54 100644 --- a/man/cal_plot_breaks.Rd +++ b/man/cal_plot_breaks.Rd @@ -10,7 +10,6 @@ cal_plot_breaks( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.9, include_ribbon = TRUE, @@ -24,21 +23,20 @@ cal_plot_breaks( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.9, include_ribbon = TRUE, include_rug = TRUE, include_points = TRUE, event_level = c("auto", "first", "second"), - ... + ..., + group = NULL ) \method{cal_plot_breaks}{tune_results}( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, num_breaks = 10, conf_level = 0.9, include_ribbon = TRUE, @@ -60,8 +58,6 @@ defaults to the prefix used by tidymodels (\code{.pred_}). 
The order of the identifiers will be considered the same as the order of the levels of the \code{truth} variable.} -\item{group}{The column identifier to group the results.} - \item{num_breaks}{The number of segments to group the probabilities. It defaults to 10.} @@ -83,6 +79,8 @@ the function decide which one to use based on the type of model (binary, multi-class or linear)} \item{...}{Additional arguments passed to the \code{tune_results} object.} + +\item{group}{The column identifier to group the results.} } \value{ A ggplot object. diff --git a/man/cal_plot_logistic.Rd b/man/cal_plot_logistic.Rd index f849957e..1de9d6b6 100644 --- a/man/cal_plot_logistic.Rd +++ b/man/cal_plot_logistic.Rd @@ -10,7 +10,6 @@ cal_plot_logistic( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.9, smooth = TRUE, include_rug = TRUE, @@ -23,20 +22,19 @@ cal_plot_logistic( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.9, smooth = TRUE, include_rug = TRUE, include_ribbon = TRUE, event_level = c("auto", "first", "second"), - ... + ..., + group = NULL ) \method{cal_plot_logistic}{tune_results}( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, conf_level = 0.9, smooth = TRUE, include_rug = TRUE, @@ -57,8 +55,6 @@ defaults to the prefix used by tidymodels (\code{.pred_}). The order of the identifiers will be considered the same as the order of the levels of the \code{truth} variable.} -\item{group}{The column identifier to group the results.} - \item{conf_level}{Confidence level to use in the visualization. It defaults to 0.9.} @@ -78,6 +74,8 @@ the function decide which one to use based on the type of model (binary, multi-class or linear)} \item{...}{Additional arguments passed to the \code{tune_results} object.} + +\item{group}{The column identifier to group the results.} } \value{ A ggplot object. 
diff --git a/man/cal_plot_regression.Rd b/man/cal_plot_regression.Rd index 224247db..86c55e98 100644 --- a/man/cal_plot_regression.Rd +++ b/man/cal_plot_regression.Rd @@ -6,32 +6,18 @@ \alias{cal_plot_regression.tune_results} \title{Regression calibration plots} \usage{ -cal_plot_regression( - .data, - truth = NULL, - estimate = NULL, - group = NULL, - smooth = TRUE, - ... -) +cal_plot_regression(.data, truth = NULL, estimate = NULL, smooth = TRUE, ...) \method{cal_plot_regression}{data.frame}( .data, truth = NULL, estimate = NULL, - group = NULL, smooth = TRUE, - ... + ..., + group = NULL ) -\method{cal_plot_regression}{tune_results}( - .data, - truth = NULL, - estimate = NULL, - group = NULL, - smooth = TRUE, - ... -) +\method{cal_plot_regression}{tune_results}(.data, truth = NULL, estimate = NULL, smooth = TRUE, ...) } \arguments{ \item{.data}{A data.frame object containing prediction and truth columns.} @@ -42,12 +28,12 @@ cal_plot_regression( \item{estimate}{The column identifier for the predictions. This should be an unquoted column name} -\item{group}{The column identifier to group the results. This should not be -a numeric variable.} - \item{smooth}{A logical: should a smoother curve be added.} \item{...}{Additional arguments passed to \code{\link[ggplot2:geom_point]{ggplot2::geom_point()}}.} + +\item{group}{The column identifier to group the results. This should not be +a numeric variable.} } \value{ A ggplot object. 
diff --git a/man/cal_plot_windowed.Rd b/man/cal_plot_windowed.Rd index f2419dc4..cdf72317 100644 --- a/man/cal_plot_windowed.Rd +++ b/man/cal_plot_windowed.Rd @@ -10,7 +10,6 @@ cal_plot_windowed( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size/2, conf_level = 0.9, @@ -25,7 +24,6 @@ cal_plot_windowed( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size/2, conf_level = 0.9, @@ -33,14 +31,14 @@ cal_plot_windowed( include_rug = TRUE, include_points = TRUE, event_level = c("auto", "first", "second"), - ... + ..., + group = NULL ) \method{cal_plot_windowed}{tune_results}( .data, truth = NULL, estimate = dplyr::starts_with(".pred"), - group = NULL, window_size = 0.1, step_size = window_size/2, conf_level = 0.9, @@ -63,8 +61,6 @@ defaults to the prefix used by tidymodels (\code{.pred_}). The order of the identifiers will be considered the same as the order of the levels of the \code{truth} variable.} -\item{group}{The column identifier to group the results.} - \item{window_size}{The size of segments. Used for the windowed probability calculations. It defaults to 10\% of segments.} @@ -89,6 +85,8 @@ the function decide which one to use based on the type of model (binary, multi-class or linear)} \item{...}{Additional arguments passed to the \code{tune_results} object.} + +\item{group}{The column identifier to group the results.} } \value{ A ggplot object. diff --git a/tests/testthat/_snaps/cal-estimate.md b/tests/testthat/_snaps/cal-estimate.md index fe74db7a..388161d1 100644 --- a/tests/testthat/_snaps/cal-estimate.md +++ b/tests/testthat/_snaps/cal-estimate.md @@ -22,6 +22,28 @@ The number of outcome factor levels isn't consistent with the calibration method. Only two class `truth` factors are allowed. 
The given levels were: 'VF', 'F', 'M', 'L' +--- + + Code + print(sl_logistic_group) + Message + + -- Probability Calibration + Method: Logistic + Type: Binary + Source class: Data Frame + Data points: 1,010, split in 2 groups + Truth variable: `Class` + Estimate variables: + `.pred_good` ==> good + `.pred_poor` ==> poor + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + # Logistic estimates work - tune_results Code @@ -38,6 +60,10 @@ `.pred_class_1` ==> class_1 `.pred_class_2` ==> class_2 +--- + + The number of outcome factor levels isn't consistent with the calibration method. Only two class `truth` factors are allowed. The given levels were: '.pred_one', '.pred_two', '.pred_three' + # Logistic spline estimates work - data.frame Code @@ -54,6 +80,28 @@ `.pred_good` ==> good `.pred_poor` ==> poor +--- + + Code + print(sl_gam_group) + Message + + -- Probability Calibration + Method: Logistic Spline + Type: Binary + Source class: Data Frame + Data points: 1,010, split in 2 groups + Truth variable: `Class` + Estimate variables: + `.pred_good` ==> good + `.pred_poor` ==> poor + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + # Logistic spline estimates work - tune_results Code @@ -87,6 +135,29 @@ `.pred_good` ==> good `.pred_poor` ==> poor +--- + + Code + print(sl_isotonic_group) + Message + + -- Probability Calibration + Method: Isotonic + Type: Binary + Source class: Data Frame + Data points: 1,010, split in 2 groups + Unique Predicted Values: 19 + Truth variable: `Class` + Estimate variables: + `.pred_good` ==> good + `.pred_poor` ==> poor + +--- + + x `group` cannot select more than one column. 
+ i The following 2 columns were selected: + i group1 and group2 + # Isotonic estimates work - tune_results Code @@ -104,6 +175,23 @@ `.pred_class_1` ==> class_1 `.pred_class_2` ==> class_2 +--- + + Code + print(mtnl_isotonic) + Message + + -- Probability Calibration + Method: Isotonic + Type: Multiclass (1 v All) + Source class: + Data points: 5,000, split in 10 groups + Truth variable: `class` + Estimate variables: + `.pred_one` ==> one + `.pred_two` ==> two + `.pred_three` ==> three + # Isotonic linear estimates work - data.frame Code @@ -115,12 +203,34 @@ Type: Regression Source class: Data Frame Data points: 2,000 - Unique Predicted Values: 44 + Unique Predicted Values: 43 Truth variable: `outcome` Estimate variables: `.pred` ==> predictions -# Isotonic Bootstrapped estimates work +--- + + Code + print(sl_logistic_group) + Message + + -- Probability Calibration + Method: Isotonic + Type: Regression + Source class: Data Frame + Data points: 2,000, split in 10 groups + Unique Predicted Values: 11 + Truth variable: `outcome` + Estimate variables: + `.pred` ==> predictions + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + +# Isotonic Bootstrapped estimates work - data.frame Code print(sl_boot) @@ -136,6 +246,28 @@ `.pred_good` ==> good `.pred_poor` ==> poor +--- + + Code + print(sl_boot_group) + Message + + -- Probability Calibration + Method: Bootstrapped Isotonic Regression + Type: Binary + Source class: Data Frame + Data points: 1,010, split in 2 groups + Truth variable: `Class` + Estimate variables: + `.pred_good` ==> good + `.pred_poor` ==> poor + +--- + + x `group` cannot select more than one column. 
+ i The following 2 columns were selected: + i group1 and group2 + # Isotonic Bootstrapped estimates work - tune_results Code @@ -152,6 +284,23 @@ `.pred_class_1` ==> class_1 `.pred_class_2` ==> class_2 +--- + + Code + print(mtnl_isotonic) + Message + + -- Probability Calibration + Method: Bootstrapped Isotonic Regression + Type: Multiclass (1 v All) + Source class: Tune Results + Data points: 5,000, split in 10 groups + Truth variable: `class` + Estimate variables: + `.pred_one` ==> one + `.pred_two` ==> two + `.pred_three` ==> three + # Beta estimates work - data.frame Code @@ -168,6 +317,28 @@ `.pred_good` ==> good `.pred_poor` ==> poor +--- + + Code + print(sl_beta_group) + Message + + -- Probability Calibration + Method: Beta + Type: Binary + Source class: Data Frame + Data points: 1,010, split in 2 groups + Truth variable: `Class` + Estimate variables: + `.pred_good` ==> good + `.pred_poor` ==> poor + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + # Beta estimates work - tune_results Code @@ -184,6 +355,23 @@ `.pred_class_1` ==> class_1 `.pred_class_2` ==> class_2 +--- + + Code + print(mtnl_isotonic) + Message + + -- Probability Calibration + Method: Beta + Type: Multiclass (1 v All) + Source class: Tune Results + Data points: 5,000, split in 10 groups + Truth variable: `class` + Estimate variables: + `.pred_one` ==> one + `.pred_two` ==> two + `.pred_three` ==> three + # Multinomial estimates work - data.frame Code @@ -218,6 +406,29 @@ `.pred_coyote` ==> coyote `.pred_gray_fox` ==> gray_fox +--- + + Code + print(sl_multi_group) + Message + + -- Probability Calibration + Method: Multinomial + Type: Multiclass + Source class: Data Frame + Data points: 110, split in 2 groups + Truth variable: `Species` + Estimate variables: + `.pred_bobcat` ==> bobcat + `.pred_coyote` ==> coyote + `.pred_gray_fox` ==> gray_fox + +--- + + x `group` cannot select more than one column. 
+ i The following 2 columns were selected: + i group1 and group2 + # Multinomial estimates work - tune_results Code @@ -265,6 +476,25 @@ Truth variable: `outcome` Estimate variable: `.pred` +--- + + Code + print(sl_logistic_group) + Message + + -- Regression Calibration + Method: Linear + Source class: Data Frame + Data points: 2,000, split in 2 groups + Truth variable: `outcome` + Estimate variable: `.pred` + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + # Linear estimates work - tune_results Code @@ -291,6 +521,25 @@ Truth variable: `outcome` Estimate variable: `.pred` +--- + + Code + print(sl_gam_group) + Message + + -- Regression Calibration + Method: Linear Spline + Source class: Data Frame + Data points: 2,000, split in 2 groups + Truth variable: `outcome` + Estimate variable: `.pred` + +--- + + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 + # Linear spline estimates work - tune_results Code diff --git a/tests/testthat/_snaps/cal-plot.md b/tests/testthat/_snaps/cal-plot.md index 64eda72c..39b0b917 100644 --- a/tests/testthat/_snaps/cal-plot.md +++ b/tests/testthat/_snaps/cal-plot.md @@ -1,48 +1,16 @@ -# Binary breaks functions work +# Binary breaks functions work with group argument - Code - testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_breaks(class, - estimate = .pred_class_1) - Condition - Error: - ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 -# Multi-class breaks functions work +# Binary logistic functions work with group argument - Code - testthat_cal_multiclass() %>% tune::collect_predictions() %>% cal_plot_breaks( - class, estimate = .pred_class_1) - Condition - Error: - ! 
The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. - -# Binary logistic functions work - - Code - testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_logistic( - class, estimate = .pred_class_1) - Condition - Error: - ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. - -# Binary windowed functions work - - Code - testthat_cal_binary() %>% tune::collect_predictions() %>% cal_plot_windowed( - class, estimate = .pred_class_1) - Condition - Error: - ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. + x `group` cannot select more than one column. + i The following 2 columns were selected: + i group1 and group2 # Event level handling works Invalid event_level entry: invalid. Valid entries are 'first', 'second', or 'auto' -# regression functions work - - Code - obj %>% tune::collect_predictions() %>% cal_plot_windowed(outcome, estimate = .pred) - Condition - Error: - ! The data have several values of '.config' but no 'groups' argument was passed. This will inappropriately pool the data. 
- diff --git a/tests/testthat/_snaps/cal-plot/cal_plot_breaks-df-group.png b/tests/testthat/_snaps/cal-plot/cal_plot_breaks-df-group.png new file mode 100644 index 00000000..8ad8dab4 Binary files /dev/null and b/tests/testthat/_snaps/cal-plot/cal_plot_breaks-df-group.png differ diff --git a/tests/testthat/_snaps/cal-plot/cal_plot_logistic-df-group.png b/tests/testthat/_snaps/cal-plot/cal_plot_logistic-df-group.png new file mode 100644 index 00000000..a4497863 Binary files /dev/null and b/tests/testthat/_snaps/cal-plot/cal_plot_logistic-df-group.png differ diff --git a/tests/testthat/cal_files/sim_multi.rds b/tests/testthat/cal_files/sim_multi.rds new file mode 100644 index 00000000..dfb4bfc4 Binary files /dev/null and b/tests/testthat/cal_files/sim_multi.rds differ diff --git a/tests/testthat/helper-cal.R b/tests/testthat/helper-cal.R index 83655164..b3aa0901 100644 --- a/tests/testthat/helper-cal.R +++ b/tests/testthat/helper-cal.R @@ -99,11 +99,22 @@ testthat_cal_multiclass <- function() { ret <- readRDS(ret_file) } .cal_env$tune_results_multi <- ret + cp <- tune::collect_predictions(ret, summarize = TRUE) + .cal_env$tune_results_multi_count <- nrow(cp) } ret } +testthat_cal_multiclass_count <- function() { + ret <- .cal_env$tune_results_multi_count + if (is.null(ret)) { + invisible(testthat_cal_multiclass()) + ret <- .cal_env$tune_results_multi_count + } + ret +} + # -------------------------- >> Multiclass (Sim) ------------------------------- testthat_cal_sim_multi <- function() { @@ -303,3 +314,44 @@ expect_snapshot_plot <- function(name, code) { expect_snapshot_file(path, name) } +has_facet <- function(x) { + inherits(x$facet, c("FacetWrap", "FacetGrid")) +} + +are_groups_configs <- function(x) { + fltrs <- purrr::map(x$estimates, ~ .x$filter) + + # Check if anything is in the filter slot + are_null <- purrr::map_lgl(fltrs, ~ all(is.null(.x))) + if (all(are_null)) { + return(FALSE) + } + + fltr_vars <- purrr::map(fltrs, all.vars) + are_config <- 
purrr::map_lgl(fltr_vars, ~ identical(.x, ".config")) + all(are_config) +} + +bin_with_configs <- function() { + set.seed(1) + segment_logistic %>% + dplyr::mutate(.config = sample(letters[1:2], nrow(segment_logistic), replace = TRUE)) +} + +mnl_with_configs <- function() { + data("hpc_cv", package = "modeldata") + + set.seed(1) + hpc_cv %>% + dplyr::mutate(.config = sample(letters[1:2], nrow(hpc_cv), replace = TRUE)) +} + +reg_with_configs <- function() { + data("solubility_test", package = "modeldata") + + set.seed(1) + + solubility_test %>% + dplyr::mutate(.config = sample(letters[1:2], nrow(solubility_test), replace = TRUE)) + +} diff --git a/tests/testthat/test-cal-estimate.R b/tests/testthat/test-cal-estimate.R index 35292726..1e7023b7 100644 --- a/tests/testthat/test-cal-estimate.R +++ b/tests/testthat/test-cal-estimate.R @@ -16,6 +16,28 @@ test_that("Logistic estimates work - data.frame", { hpc_cv %>% cal_estimate_logistic(truth = obs, estimate = c(VF:L)) ) + sl_logistic_group <- segment_logistic %>% + dplyr::mutate(group = .pred_poor > 0.5) %>% + cal_estimate_logistic(Class, group = group, smooth = FALSE) + expect_false(are_groups_configs(sl_logistic_group)) + + expect_cal_type(sl_logistic_group, "binary") + expect_cal_method(sl_logistic_group, "Logistic") + expect_cal_estimate(sl_logistic_group, "butchered_glm") + expect_cal_rows(sl_logistic_group) + expect_snapshot(print(sl_logistic_group)) + + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_logistic(Class, group = c(group1, group2), smooth = FALSE) + ) + + lgst_configs <- + bin_with_configs() %>% + cal_estimate_logistic(truth = Class, smooth = FALSE) + expect_true(are_groups_configs(lgst_configs)) + }) test_that("Logistic estimates work - tune_results", { @@ -24,6 +46,11 @@ test_that("Logistic estimates work - tune_results", { expect_cal_method(tl_logistic, "Logistic") expect_cal_estimate(tl_logistic, "butchered_glm") 
expect_snapshot(print(tl_logistic)) + expect_true(are_groups_configs(tl_logistic)) + + expect_snapshot_error( + cal_estimate_logistic(testthat_cal_multiclass(), smooth = FALSE) + ) }) # ----------------------------- Logistic Spline -------------------------------- @@ -34,6 +61,27 @@ test_that("Logistic spline estimates work - data.frame", { expect_cal_estimate(sl_gam, "butchered_gam") expect_cal_rows(sl_gam) expect_snapshot(print(sl_gam)) + + sl_gam_group <- segment_logistic %>% + dplyr::mutate(group = .pred_poor > 0.5) %>% + cal_estimate_logistic(Class, group = group) + + expect_cal_type(sl_gam_group, "binary") + expect_cal_method(sl_gam_group, "Logistic Spline") + expect_cal_estimate(sl_gam_group, "butchered_gam") + expect_cal_rows(sl_gam_group) + expect_snapshot(print(sl_gam_group)) + + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_logistic(Class, group = c(group1, group2)) + ) + + lgst_configs <- + bin_with_configs() %>% + cal_estimate_logistic(truth = Class, smooth = TRUE) + expect_true(are_groups_configs(lgst_configs)) }) test_that("Logistic spline estimates work - tune_results", { @@ -42,6 +90,7 @@ test_that("Logistic spline estimates work - tune_results", { expect_cal_method(tl_gam, "Logistic Spline") expect_cal_estimate(tl_gam, "butchered_gam") expect_snapshot(print(tl_gam)) + expect_true(are_groups_configs(tl_gam)) expect_equal( testthat_cal_binary_count(), @@ -57,6 +106,31 @@ test_that("Isotonic estimates work - data.frame", { expect_cal_method(sl_isotonic, "Isotonic") expect_cal_rows(sl_isotonic) expect_snapshot(print(sl_isotonic)) + + sl_isotonic_group <- segment_logistic %>% + dplyr::mutate(group = .pred_poor > 0.5) %>% + cal_estimate_isotonic(Class, group = group) + + expect_cal_type(sl_isotonic_group, "binary") + expect_cal_method(sl_isotonic_group, "Isotonic") + expect_cal_rows(sl_isotonic_group) + expect_snapshot(print(sl_isotonic_group)) + + expect_snapshot_error( + segment_logistic %>% + 
dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_isotonic(Class, group = c(group1, group2)) + ) + + iso_configs <- + bin_with_configs() %>% + cal_estimate_isotonic(truth = Class) + expect_true(are_groups_configs(iso_configs)) + + mltm_configs <- + mnl_with_configs() %>% + cal_estimate_isotonic(truth = obs, estimate = c(VF:L)) + expect_true(are_groups_configs(mltm_configs)) }) test_that("Isotonic estimates work - tune_results", { @@ -65,11 +139,27 @@ test_that("Isotonic estimates work - tune_results", { expect_cal_type(tl_isotonic, "binary") expect_cal_method(tl_isotonic, "Isotonic") expect_snapshot(print(tl_isotonic)) + expect_true(are_groups_configs(tl_isotonic)) expect_equal( testthat_cal_binary_count(), nrow(cal_apply(testthat_cal_binary(), tl_isotonic)) ) + + # ------------------------------------------------------------------------------ + # multinomial outcomes + + set.seed(100) + mtnl_isotonic <- cal_estimate_isotonic(testthat_cal_multiclass()) + expect_cal_type(mtnl_isotonic, "one_vs_all") + expect_cal_method(mtnl_isotonic, "Isotonic") + expect_snapshot(print(mtnl_isotonic)) + expect_true(are_groups_configs(mtnl_isotonic)) + + expect_equal( + testthat_cal_multiclass_count(), + nrow(cal_apply(testthat_cal_multiclass(), mtnl_isotonic)) + ) }) test_that("Isotonic linear estimates work - data.frame", { @@ -78,14 +168,60 @@ test_that("Isotonic linear estimates work - data.frame", { expect_cal_method(sl_logistic, "Isotonic") expect_cal_rows(sl_logistic, 2000) expect_snapshot(print(sl_logistic)) + + sl_logistic_group <- boosting_predictions_oob %>% + cal_estimate_isotonic(outcome, estimate = .pred, group = id) + + expect_cal_type(sl_logistic_group, "regression") + expect_cal_method(sl_logistic_group, "Isotonic") + expect_cal_rows(sl_logistic_group, 2000) + expect_snapshot(print(sl_logistic_group)) + + expect_snapshot_error( + boosting_predictions_oob %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_isotonic(outcome, estimate = .pred, group = 
c(group1, group2)) + ) + + iso_configs <- + reg_with_configs() %>% + cal_estimate_isotonic(truth = solubility, estimate = prediction) + expect_true(are_groups_configs(iso_configs)) }) # -------------------------- Isotonic Bootstrapped ----------------------------- -test_that("Isotonic Bootstrapped estimates work", { +test_that("Isotonic Bootstrapped estimates work - data.frame", { + set.seed(1) sl_boot <- cal_estimate_isotonic_boot(segment_logistic, Class) expect_cal_type(sl_boot, "binary") expect_cal_method(sl_boot, "Bootstrapped Isotonic Regression") expect_snapshot(print(sl_boot)) + + sl_boot_group <- segment_logistic %>% + dplyr::mutate(group = .pred_poor > 0.5) %>% + cal_estimate_isotonic_boot(Class, group = group) + + expect_cal_type(sl_boot_group, "binary") + expect_cal_method(sl_boot_group, "Bootstrapped Isotonic Regression") + expect_snapshot(print(sl_boot_group)) + expect_false(are_groups_configs(sl_boot_group)) + + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_isotonic_boot(Class, group = c(group1, group2)) + ) + + isobt_configs <- + bin_with_configs() %>% + cal_estimate_isotonic_boot(truth = Class) + expect_true(are_groups_configs(isobt_configs)) + + mltm_configs <- + mnl_with_configs() %>% + cal_estimate_isotonic_boot(truth = obs, estimate = c(VF:L)) + expect_true(are_groups_configs(mltm_configs)) + }) test_that("Isotonic Bootstrapped estimates work - tune_results", { @@ -94,11 +230,27 @@ test_that("Isotonic Bootstrapped estimates work - tune_results", { expect_cal_type(tl_isotonic, "binary") expect_cal_method(tl_isotonic, "Bootstrapped Isotonic Regression") expect_snapshot(print(tl_isotonic)) + expect_true(are_groups_configs(tl_isotonic)) expect_equal( testthat_cal_binary_count(), nrow(cal_apply(testthat_cal_binary(), tl_isotonic)) ) + + # ------------------------------------------------------------------------------ + # multinomial outcomes + + set.seed(100) + mtnl_isotonic <- 
cal_estimate_isotonic_boot(testthat_cal_multiclass()) + expect_cal_type(mtnl_isotonic, "one_vs_all") + expect_cal_method(mtnl_isotonic, "Bootstrapped Isotonic Regression") + expect_snapshot(print(mtnl_isotonic)) + expect_true(are_groups_configs(mtnl_isotonic)) + + expect_equal( + testthat_cal_multiclass_count(), + nrow(cal_apply(testthat_cal_multiclass(), mtnl_isotonic)) + ) }) # ----------------------------------- Beta ------------------------------------- @@ -108,6 +260,31 @@ test_that("Beta estimates work - data.frame", { expect_cal_method(sl_beta, "Beta") expect_cal_rows(sl_beta) expect_snapshot(print(sl_beta)) + + sl_beta_group <- segment_logistic %>% + dplyr::mutate(group = .pred_poor > 0.5) %>% + cal_estimate_beta(Class, smooth = FALSE, group = group) + + expect_cal_type(sl_beta_group, "binary") + expect_cal_method(sl_beta_group, "Beta") + expect_cal_rows(sl_beta_group) + expect_snapshot(print(sl_beta_group)) + + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_beta(Class, smooth = FALSE, group = c(group1, group2)) + ) + + beta_configs <- + bin_with_configs() %>% + cal_estimate_beta(truth = Class) + expect_true(are_groups_configs(beta_configs)) + + mltm_configs <- + mnl_with_configs() %>% + cal_estimate_beta(truth = obs, estimate = c(VF:L)) + expect_true(are_groups_configs(mltm_configs)) }) test_that("Beta estimates work - tune_results", { @@ -115,11 +292,29 @@ test_that("Beta estimates work - tune_results", { expect_cal_type(tl_beta, "binary") expect_cal_method(tl_beta, "Beta") expect_snapshot(print(tl_beta)) + expect_true(are_groups_configs(tl_beta)) expect_equal( testthat_cal_binary_count(), nrow(cal_apply(testthat_cal_binary(), tl_beta)) ) + + # ------------------------------------------------------------------------------ + # multinomial outcomes + + set.seed(100) + suppressWarnings( + mtnl_isotonic <- cal_estimate_beta(testthat_cal_multiclass()) + ) + expect_cal_type(mtnl_isotonic, "one_vs_all") + 
expect_cal_method(mtnl_isotonic, "Beta") + expect_snapshot(print(mtnl_isotonic)) + expect_true(are_groups_configs(mtnl_isotonic)) + + expect_equal( + testthat_cal_multiclass_count(), + nrow(cal_apply(testthat_cal_multiclass(), mtnl_isotonic)) + ) }) # ------------------------------ Multinomial ----------------------------------- @@ -135,6 +330,26 @@ test_that("Multinomial estimates work - data.frame", { expect_cal_method(sp_smth_multi, "Multinomial") expect_cal_rows(sp_smth_multi, n = 110) expect_snapshot(print(sp_smth_multi)) + + sl_multi_group <- species_probs %>% + dplyr::mutate(group = .pred_bobcat > 0.5) %>% + cal_estimate_multinomial(Species, smooth = FALSE, group = group) + + expect_cal_type(sl_multi_group, "multiclass") + expect_cal_method(sl_multi_group, "Multinomial") + expect_cal_rows(sl_multi_group, n = 110) + expect_snapshot(print(sl_multi_group)) + + expect_snapshot_error( + species_probs %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_multinomial(Species, smooth = FALSE, group = c(group1, group2)) + ) + + mltm_configs <- + mnl_with_configs() %>% + cal_estimate_multinomial(truth = obs, estimate = c(VF:L), smooth = FALSE) + expect_true(are_groups_configs(mltm_configs)) }) test_that("Multinomial estimates work - tune_results", { @@ -142,6 +357,7 @@ test_that("Multinomial estimates work - tune_results", { expect_cal_type(tl_multi, "multiclass") expect_cal_method(tl_multi, "Multinomial") expect_snapshot(print(tl_multi)) + expect_true(are_groups_configs(tl_multi)) expect_equal( testthat_cal_multiclass() %>% @@ -181,6 +397,28 @@ test_that("Linear estimates work - data.frame", { expect_cal_estimate(sl_logistic, "butchered_glm") expect_cal_rows(sl_logistic, 2000) expect_snapshot(print(sl_logistic)) + expect_false(are_groups_configs(sl_logistic)) + + sl_logistic_group <- boosting_predictions_oob %>% + dplyr::mutate(group = .pred > 0.5) %>% + cal_estimate_linear(outcome, smooth = FALSE, group = group) + + expect_cal_type(sl_logistic_group, 
"regression") + expect_cal_method(sl_logistic_group, "Linear") + expect_cal_estimate(sl_logistic_group, "butchered_glm") + expect_cal_rows(sl_logistic_group, 2000) + expect_snapshot(print(sl_logistic_group)) + + expect_snapshot_error( + boosting_predictions_oob %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_linear(outcome, smooth = FALSE, group = c(group1, group2)) + ) + + lin_configs <- + reg_with_configs() %>% + cal_estimate_linear(truth = solubility, estimate = prediction, smooth = FALSE) + expect_true(are_groups_configs(lin_configs)) }) test_that("Linear estimates work - tune_results", { @@ -189,6 +427,7 @@ test_that("Linear estimates work - tune_results", { expect_cal_method(tl_linear, "Linear") expect_cal_estimate(tl_linear, "butchered_glm") expect_snapshot(print(tl_linear)) + expect_true(are_groups_configs(tl_linear)) }) # ----------------------------- Linear Spline -------------------------------- @@ -199,6 +438,27 @@ test_that("Linear spline estimates work - data.frame", { expect_cal_estimate(sl_gam, "butchered_gam") expect_cal_rows(sl_gam, 2000) expect_snapshot(print(sl_gam)) + + sl_gam_group <- boosting_predictions_oob %>% + dplyr::mutate(group = .pred > 0.5) %>% + cal_estimate_linear(outcome, group = group) + + expect_cal_type(sl_gam_group, "regression") + expect_cal_method(sl_gam_group, "Linear Spline") + expect_cal_estimate(sl_gam_group, "butchered_gam") + expect_cal_rows(sl_gam_group, 2000) + expect_snapshot(print(sl_gam_group)) + + expect_snapshot_error( + boosting_predictions_oob %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_estimate_linear(outcome, group = c(group1, group2)) + ) + + lin_configs <- + reg_with_configs() %>% + cal_estimate_linear(truth = solubility, estimate = prediction, smooth = TRUE) + expect_true(are_groups_configs(lin_configs)) }) test_that("Linear spline estimates work - tune_results", { @@ -207,6 +467,7 @@ test_that("Linear spline estimates work - tune_results", { expect_cal_method(tl_gam, "Linear 
Spline") expect_cal_estimate(tl_gam, "butchered_gam") expect_snapshot(print(tl_gam)) + expect_true(are_groups_configs(tl_gam)) expect_equal( testthat_cal_reg_count(), diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R index 6bcaf4c8..cb0ecd0f 100644 --- a/tests/testthat/test-cal-plot.R +++ b/tests/testthat/test-cal-plot.R @@ -23,13 +23,31 @@ test_that("Binary breaks functions work", { "ggplot" ) - expect_snapshot( - error = TRUE, - testthat_cal_binary() %>% - tune::collect_predictions() %>% - cal_plot_breaks(class, estimate = .pred_class_1) + brks_configs <- + bin_with_configs() %>% cal_plot_breaks(truth = Class, estimate = .pred_good) + expect_true(has_facet(brks_configs)) +}) + +test_that("Binary breaks functions work with group argument", { + res <- segment_logistic %>% + dplyr::mutate(id = dplyr::row_number() %% 2) %>% + cal_plot_breaks(Class, .pred_good, group = id) + + expect_s3_class( + res, + "ggplot" + ) + + expect_snapshot_plot( + "cal_plot_breaks-df-group", + print(res) ) + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_plot_breaks(Class, .pred_good, group = c(group1, group2)) + ) }) test_that("Multi-class breaks functions work", { @@ -53,21 +71,29 @@ test_that("Multi-class breaks functions work", { seq(0.05, 0.95, by = 0.10) ) - expect_s3_class( - cal_plot_breaks(testthat_cal_multiclass()), - "ggplot" - ) + multi_configs <- cal_plot_breaks(testthat_cal_multiclass()) + # should be faceted by .config and class + expect_s3_class(multi_configs, "ggplot") + expect_true(inherits(multi_configs$facet, "FacetGrid")) expect_error( cal_plot_breaks(species_probs, Species, event_level = "second") ) - expect_snapshot( - error = TRUE, - testthat_cal_multiclass() %>% - tune::collect_predictions() %>% - cal_plot_breaks(class, estimate = .pred_class_1) - ) + # ------------------------------------------------------------------------------ + # multinomial outcome, binary logistic plots + + 
multi_configs_from_tune <- + testthat_cal_multiclass() %>% cal_plot_breaks() + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + multi_configs_from_df <- + mnl_with_configs() %>% cal_plot_breaks(truth = obs, estimate = c(VF:L)) + expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) }) test_that("Binary logistic functions work", { @@ -89,6 +115,7 @@ test_that("Binary logistic functions work", { x21 <- cal_plot_logistic(segment_logistic, Class, .pred_good) expect_s3_class(x21, "ggplot") + expect_false(has_facet(x21)) x22 <- .cal_table_logistic(testthat_cal_binary()) @@ -116,6 +143,7 @@ test_that("Binary logistic functions work", { x23 <- cal_plot_logistic(testthat_cal_binary()) expect_s3_class(x23, "ggplot") + expect_true(has_facet(x23)) x24 <- .cal_table_logistic(segment_logistic, Class, .pred_good, smooth = FALSE) @@ -141,12 +169,53 @@ test_that("Binary logistic functions work", { nrow(x25) ) - expect_snapshot( - error = TRUE, - testthat_cal_binary() %>% - tune::collect_predictions() %>% - cal_plot_logistic(class, estimate = .pred_class_1) + lgst_configs <- + bin_with_configs() %>% cal_plot_logistic(truth = Class, estimate = .pred_good) + expect_true(has_facet(lgst_configs)) + + # ------------------------------------------------------------------------------ + # multinomial outcome, binary logistic plots + + multi_configs_from_tune <- + testthat_cal_multiclass() %>% cal_plot_logistic(smooth = FALSE) + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + + multi_configs_from_df <- + mnl_with_configs() %>% cal_plot_logistic(truth = obs, estimate = c(VF:L)) + expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config 
and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) + +}) + +test_that("Binary logistic functions work with group argument", { + res <- segment_logistic %>% + dplyr::mutate(id = dplyr::row_number() %% 2) %>% + cal_plot_logistic(Class, .pred_good, group = id) + + expect_s3_class( + res, + "ggplot" ) + expect_true(has_facet(res)) + + expect_snapshot_plot( + "cal_plot_logistic-df-group", + print(res) + ) + + expect_snapshot_error( + segment_logistic %>% + dplyr::mutate(group1 = 1, group2 = 2) %>% + cal_plot_logistic(Class, .pred_good, group = c(group1, group2)) + ) + + lgst_configs <- + bin_with_configs() %>% cal_plot_logistic(truth = Class, estimate = .pred_good) + expect_true(has_facet(lgst_configs)) }) test_that("Binary windowed functions work", { @@ -182,6 +251,7 @@ test_that("Binary windowed functions work", { x31 <- cal_plot_windowed(segment_logistic, Class, .pred_good) expect_s3_class(x31, "ggplot") + expect_false(has_facet(x31)) x32 <- .cal_table_windowed( testthat_cal_binary(), @@ -214,13 +284,28 @@ test_that("Binary windowed functions work", { x33 <- cal_plot_windowed(testthat_cal_binary()) expect_s3_class(x33, "ggplot") + expect_true(has_facet(x33)) - expect_snapshot( - error = TRUE, - testthat_cal_binary() %>% - tune::collect_predictions() %>% - cal_plot_windowed(class, estimate = .pred_class_1) - ) + win_configs <- + bin_with_configs() %>% cal_plot_windowed(truth = Class, estimate = .pred_good) + expect_true(has_facet(win_configs)) + + + # ------------------------------------------------------------------------------ + # multinomial outcome, binary windowed plots + + multi_configs_from_tune <- + testthat_cal_multiclass() %>% cal_plot_windowed() + expect_s3_class(multi_configs_from_tune, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_tune$facet, "FacetGrid")) + + + multi_configs_from_df <- + mnl_with_configs() %>% cal_plot_windowed(truth = obs, estimate = c(VF:L)) + 
expect_s3_class(multi_configs_from_df, "ggplot") + # should be faceted by .config and class + expect_true(inherits(multi_configs_from_df$facet, "FacetGrid")) }) test_that("Event level handling works", { @@ -359,15 +444,47 @@ test_that("regression functions work", { "rs-scat-group-opts", print(cal_plot_regression(obj), alpha = 1/5, smooth = FALSE) ) - - expect_snapshot( - error = TRUE, - obj %>% - tune::collect_predictions() %>% - cal_plot_windowed(outcome, estimate = .pred) - ) expect_snapshot_plot( "df-scat-lin", print(cal_plot_regression(boosting_predictions_oob, outcome, .pred, smooth = FALSE)) ) }) + +# ------------------------------------------------------------------------------ + +test_that("don't facet if there is only one .config", { + class_data <- testthat_cal_binary() + + class_data$.predictions <- lapply( + class_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor1_Model1") + ) + + res_breaks <- cal_plot_breaks(class_data) + + expect_null(res_breaks$data[[".config"]]) + expect_s3_class(res_breaks, "ggplot") + + res_logistic <- cal_plot_logistic(class_data) + + expect_null(res_logistic$data[[".config"]]) + expect_s3_class(res_logistic, "ggplot") + + res_windowed <- cal_plot_windowed(class_data) + + expect_null(res_windowed$data[[".config"]]) + expect_s3_class(res_windowed, "ggplot") + + reg_data <- testthat_cal_reg() + + reg_data$.predictions <- lapply( + reg_data$.predictions, + function(x) dplyr::filter(x, .config == "Preprocessor01_Model1") + ) + + res_regression <- cal_plot_regression(reg_data) + + expect_null(res_regression$data[[".config"]]) + expect_s3_class(res_regression, "ggplot") +}) +