From 637e56fe101ab68aff26fa095f2ddff2a4310a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Tue, 20 Feb 2024 21:28:57 -0500 Subject: [PATCH 1/6] changes for #127 --- NEWS.md | 2 ++ R/cal-estimate-utils.R | 34 +++++++++++++++------------------- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/NEWS.md b/NEWS.md index 2c8d0757..e749169b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # probably (development version) +* Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127). + # probably 1.0.1 * The conformal functions `int_conformal_infer_*()` were renamed to `int_conformal_*()`. diff --git a/R/cal-estimate-utils.R b/R/cal-estimate-utils.R index 15b1870a..c6e32c22 100644 --- a/R/cal-estimate-utils.R +++ b/R/cal-estimate-utils.R @@ -194,27 +194,23 @@ tidyselect_cols <- function(.data, x) { # any tidyeval variable calls made prior to calling the calibration split_dplyr_groups <- function(.data) { if (dplyr::is_grouped_df(.data)) { - .data %>% - dplyr::summarise(.groups = "drop") %>% - purrr::transpose() %>% - purrr::map(~ { - purrr::imap(.x, ~ expr(!!parse_expr(.y) == !!.x)) %>% - purrr::reduce(function(x, y) expr(!!x & !!y)) - }) %>% - purrr::map(~ { - df <- .data %>% - dplyr::filter(, !!.x) %>% - dplyr::ungroup() - - list( - data = df, - filter = .x, - rows = nrow(df) - ) - }) + grp_keys <- .data %>% dplyr::group_keys() + grp_var <- .data %>% dplyr::group_vars() + grp_sym <- rlang::sym(grp_var) + grp_data <- .data %>% tidyr::nest() + grp_filters <- + purrr::map(grp_keys[[1]], ~ expr(!!parse_expr(grp_var) == !!.x)) + grp_n <- purrr::map_int(grp_data$data, nrow) + res <- vector(mode = "list", length = length(grp_filters)) + for (i in seq_along(res)) { + res[[i]]$data <- grp_data$data[[i]] + res[[i]]$filter <- grp_filters[[i]] + res[[i]]$rows <- grp_n[[i]] + } } else { - list(list(data = .data)) + res <- list(list(data = .data)) } + res } stop_null_parameters <- function(x) { From 30dd9c283c657ba1cfb3fdaa753b31021c07edac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 22 Feb 2024 10:05:53 -0500 Subject: [PATCH 2/6] set seed to get reproducible test --- tests/testthat/_snaps/cal-estimate.md | 2 +- tests/testthat/test-cal-estimate.R | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/testthat/_snaps/cal-estimate.md b/tests/testthat/_snaps/cal-estimate.md index e64c2543..bdb7f406 100644 --- a/tests/testthat/_snaps/cal-estimate.md +++ b/tests/testthat/_snaps/cal-estimate.md @@ -151,7 +151,7 @@ Type: Binary Source class: Data Frame Data points: 1,010, split in 2 groups - Unique Predicted Values: 19 + Unique Predicted Values: 59 Truth variable: `Class` Estimate variables: `.pred_good` ==> good diff --git a/tests/testthat/test-cal-estimate.R b/tests/testthat/test-cal-estimate.R index 1d73d69a..68fdff05 100644 --- a/tests/testthat/test-cal-estimate.R +++ b/tests/testthat/test-cal-estimate.R @@ -112,6 +112,7 @@ test_that("Isotonic estimates work - data.frame", { expect_cal_rows(sl_isotonic) expect_snapshot(print(sl_isotonic)) + set.seed(100) sl_isotonic_group <- segment_logistic %>% dplyr::mutate(group = .pred_poor > 0.5) %>% cal_estimate_isotonic(Class, .by = group) @@ -121,17 +122,20 @@ test_that("Isotonic estimates work - data.frame", { expect_cal_rows(sl_isotonic_group) expect_snapshot(print(sl_isotonic_group)) + set.seed(100) expect_snapshot_error( segment_logistic %>% dplyr::mutate(group1 = 1, group2 = 2) %>% cal_estimate_isotonic(Class, .by = c(group1, group2)) ) + set.seed(100) iso_configs <- bin_with_configs() %>% cal_estimate_isotonic(truth = Class) expect_true(are_groups_configs(iso_configs)) + set.seed(100) mltm_configs <- mnl_with_configs() %>% cal_estimate_isotonic(truth = obs, estimate = c(VF:L)) From d0bb1b25f1c3dff4a8c7eb14ee6c094ab275c8d0 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 22 Feb 2024 15:50:32 -0500 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Simon P. Couch --- R/cal-estimate-utils.R | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/R/cal-estimate-utils.R b/R/cal-estimate-utils.R index c6e32c22..c594c6a8 100644 --- a/R/cal-estimate-utils.R +++ b/R/cal-estimate-utils.R @@ -195,11 +195,11 @@ tidyselect_cols <- function(.data, x) { split_dplyr_groups <- function(.data) { if (dplyr::is_grouped_df(.data)) { grp_keys <- .data %>% dplyr::group_keys() + grp_keys <- purrr::map(grp_keys, as.character) grp_var <- .data %>% dplyr::group_vars() - grp_sym <- rlang::sym(grp_var) grp_data <- .data %>% tidyr::nest() grp_filters <- - purrr::map(grp_keys[[1]], ~ expr(!!parse_expr(grp_var) == !!.x)) + purrr::pmap(grp_keys, create_filter_expr) grp_n <- purrr::map_int(grp_data$data, nrow) res <- vector(mode = "list", length = length(grp_filters)) for (i in seq_along(res)) { @@ -213,6 +213,11 @@ split_dplyr_groups <- function(.data) { res } +create_filter_expr <- function(...) { + purrr::imap(..., ~ expr(!!parse_expr(.y) == !!.x)) %>% + purrr::reduce(function(x, y) expr(!!x & !!y)) +} + stop_null_parameters <- function(x) { if (!is.null(x)) { rlang::abort("The `parameters` argument is only valid for `tune_results`.") From e7f661b7c7c51868b3e5c051f947b0c49ce08f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 22 Feb 2024 19:55:49 -0500 Subject: [PATCH 4/6] revert commit --- R/cal-estimate-utils.R | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/R/cal-estimate-utils.R b/R/cal-estimate-utils.R index c594c6a8..b0626c06 100644 --- a/R/cal-estimate-utils.R +++ b/R/cal-estimate-utils.R @@ -187,7 +187,6 @@ tidyselect_cols <- function(.data, x) { ) } - # dplyr::group_map() does not pass the parent function's `...`, it overrides it # and there seems to be no way to change it. This function will split the the # data set by all the combination of the grouped variables. It will respect @@ -198,8 +197,7 @@ split_dplyr_groups <- function(.data) { grp_keys <- purrr::map(grp_keys, as.character) grp_var <- .data %>% dplyr::group_vars() grp_data <- .data %>% tidyr::nest() - grp_filters <- - purrr::pmap(grp_keys, create_filter_expr) + grp_filters <- purrr::map(grp_keys[[1]], ~ expr(!!parse_expr(grp_var) == !!.x)) grp_n <- purrr::map_int(grp_data$data, nrow) res <- vector(mode = "list", length = length(grp_filters)) for (i in seq_along(res)) { From 24daedc150d07f2a73b6fba2bcd032f586d41890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 22 Feb 2024 19:55:57 -0500 Subject: [PATCH 5/6] missing skips --- tests/testthat/test-cal-plot.R | 2 ++ tests/testthat/test-cal-validate.R | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R index 5addeb86..ac44de90 100644 --- a/tests/testthat/test-cal-plot.R +++ b/tests/testthat/test-cal-plot.R @@ -543,6 +543,7 @@ test_that("regression functions work", { res <- cal_plot_regression(obj) expect_s3_class(res, "ggplot") + skip_if_not_installed("tune", "1.2.0") expect_equal( res$data[0,], dplyr::tibble(.pred = numeric(0), .row = numeric(0), @@ -572,6 +573,7 @@ test_that("regression functions work", { res <- print(cal_plot_regression(obj), alpha = 1 / 5, smooth = FALSE) expect_s3_class(res, "ggplot") + skip_if_not_installed("tune", "1.2.0") expect_equal( res$data[0,], dplyr::tibble(.pred = numeric(0), .row = numeric(0), diff --git a/tests/testthat/test-cal-validate.R b/tests/testthat/test-cal-validate.R index e7e8b908..8c8382bb 100644 --- a/tests/testthat/test-cal-validate.R +++ b/tests/testthat/test-cal-validate.R @@ -401,6 +401,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", { names(val_with_pred), c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") ) + + skip_if_not_installed("tune", "1.2.0") expect_equal( names(val_with_pred$.predictions_cal[[1]]), c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class") From 96e0a5b24a061106d1b3173a547e3d98beaf583e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 22 Feb 2024 19:56:59 -0500 Subject: [PATCH 6/6] fix version number and change title --- DESCRIPTION | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7fc3b68f..4aa5712a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: probably -Title: Tools for Post-Processing Class Probability Estimates +Title: Tools for Post-Processing Predicted Values Version: 1.0.1.9000 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), @@ -31,7 +31,7 @@ Imports: rlang (>= 1.0.4), tidyr (>= 1.3.0), tidyselect (>= 1.1.2), - tune (>= 1.1.2.9020), + tune (>= 1.1.2), vctrs (>= 0.4.1), withr, workflows (>= 1.1.4), @@ -51,8 +51,6 @@ Suggests: rmarkdown, rsample, testthat (>= 3.0.0) -Remotes: - tidymodels/tune VignetteBuilder: knitr ByteCompile: true