diff --git a/DESCRIPTION b/DESCRIPTION index 7fc3b68f..4aa5712a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,5 +1,5 @@ Package: probably -Title: Tools for Post-Processing Class Probability Estimates +Title: Tools for Post-Processing Predicted Values Version: 1.0.1.9000 Authors@R: c( person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")), @@ -31,7 +31,7 @@ Imports: rlang (>= 1.0.4), tidyr (>= 1.3.0), tidyselect (>= 1.1.2), - tune (>= 1.1.2.9020), + tune (>= 1.1.2), vctrs (>= 0.4.1), withr, workflows (>= 1.1.4), @@ -51,8 +51,6 @@ Suggests: rmarkdown, rsample, testthat (>= 3.0.0) -Remotes: - tidymodels/tune VignetteBuilder: knitr ByteCompile: true diff --git a/NEWS.md b/NEWS.md index 2c8d0757..e749169b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # probably (development version) +* Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127). + # probably 1.0.1 * The conformal functions `int_conformal_infer_*()` were renamed to `int_conformal_*()`. diff --git a/R/cal-estimate-utils.R b/R/cal-estimate-utils.R index 15b1870a..b0626c06 100644 --- a/R/cal-estimate-utils.R +++ b/R/cal-estimate-utils.R @@ -187,34 +187,33 @@ tidyselect_cols <- function(.data, x) { ) } - # dplyr::group_map() does not pass the parent function's `...`, it overrides it # and there seems to be no way to change it. This function will split the the # data set by all the combination of the grouped variables. It will respect # any tidyeval variable calls made prior to calling the calibration split_dplyr_groups <- function(.data) { if (dplyr::is_grouped_df(.data)) { - .data %>% - dplyr::summarise(.groups = "drop") %>% - purrr::transpose() %>% - purrr::map(~ { - purrr::imap(.x, ~ expr(!!parse_expr(.y) == !!.x)) %>% - purrr::reduce(function(x, y) expr(!!x & !!y)) - }) %>% - purrr::map(~ { - df <- .data %>% - dplyr::filter(, !!.x) %>% - dplyr::ungroup() - - list( - data = df, - filter = .x, - rows = nrow(df) - ) - }) + grp_keys <- .data %>% dplyr::group_keys() + grp_keys <- purrr::map(grp_keys, as.character) + grp_var <- .data %>% dplyr::group_vars() + grp_data <- .data %>% tidyr::nest() + grp_filters <- purrr::map(grp_keys[[1]], ~ expr(!!parse_expr(grp_var) == !!.x)) + grp_n <- purrr::map_int(grp_data$data, nrow) + res <- vector(mode = "list", length = length(grp_filters)) + for (i in seq_along(res)) { + res[[i]]$data <- grp_data$data[[i]] + res[[i]]$filter <- grp_filters[[i]] + res[[i]]$rows <- grp_n[[i]] + } } else { - list(list(data = .data)) + res <- list(list(data = .data)) } + res +} + +create_filter_expr <- function(...) { + purrr::imap(..., ~ expr(!!parse_expr(.y) == !!.x)) %>% + purrr::reduce(function(x, y) expr(!!x & !!y)) } stop_null_parameters <- function(x) { diff --git a/tests/testthat/_snaps/cal-estimate.md b/tests/testthat/_snaps/cal-estimate.md index e64c2543..bdb7f406 100644 --- a/tests/testthat/_snaps/cal-estimate.md +++ b/tests/testthat/_snaps/cal-estimate.md @@ -151,7 +151,7 @@ Type: Binary Source class: Data Frame Data points: 1,010, split in 2 groups - Unique Predicted Values: 19 + Unique Predicted Values: 59 Truth variable: `Class` Estimate variables: `.pred_good` ==> good diff --git a/tests/testthat/test-cal-estimate.R b/tests/testthat/test-cal-estimate.R index 1d73d69a..68fdff05 100644 --- a/tests/testthat/test-cal-estimate.R +++ b/tests/testthat/test-cal-estimate.R @@ -112,6 +112,7 @@ test_that("Isotonic estimates work - data.frame", { expect_cal_rows(sl_isotonic) expect_snapshot(print(sl_isotonic)) + set.seed(100) sl_isotonic_group <- segment_logistic %>% dplyr::mutate(group = .pred_poor > 0.5) %>% cal_estimate_isotonic(Class, .by = group) @@ -121,17 +122,20 @@ test_that("Isotonic estimates work - data.frame", { expect_cal_rows(sl_isotonic_group) expect_snapshot(print(sl_isotonic_group)) + set.seed(100) expect_snapshot_error( segment_logistic %>% dplyr::mutate(group1 = 1, group2 = 2) %>% cal_estimate_isotonic(Class, .by = c(group1, group2)) ) + set.seed(100) iso_configs <- bin_with_configs() %>% cal_estimate_isotonic(truth = Class) expect_true(are_groups_configs(iso_configs)) + set.seed(100) mltm_configs <- mnl_with_configs() %>% cal_estimate_isotonic(truth = obs, estimate = c(VF:L)) diff --git a/tests/testthat/test-cal-plot.R b/tests/testthat/test-cal-plot.R index 5addeb86..ac44de90 100644 --- a/tests/testthat/test-cal-plot.R +++ b/tests/testthat/test-cal-plot.R @@ -543,6 +543,7 @@ test_that("regression functions work", { res <- cal_plot_regression(obj) expect_s3_class(res, "ggplot") + skip_if_not_installed("tune", "1.2.0") expect_equal( res$data[0,], dplyr::tibble(.pred = numeric(0), .row = numeric(0), @@ -572,6 +573,7 @@ test_that("regression functions work", { res <- print(cal_plot_regression(obj), alpha = 1 / 5, smooth = FALSE) expect_s3_class(res, "ggplot") + skip_if_not_installed("tune", "1.2.0") expect_equal( res$data[0,], dplyr::tibble(.pred = numeric(0), .row = numeric(0), diff --git a/tests/testthat/test-cal-validate.R b/tests/testthat/test-cal-validate.R index e7e8b908..8c8382bb 100644 --- a/tests/testthat/test-cal-validate.R +++ b/tests/testthat/test-cal-validate.R @@ -401,6 +401,8 @@ test_that("Multinomial calibration validation with `fit_resamples`", { names(val_with_pred), c("splits", "id", ".notes", ".predictions", ".metrics", ".metrics_cal", ".predictions_cal") ) + + skip_if_not_installed("tune", "1.2.0") expect_equal( names(val_with_pred$.predictions_cal[[1]]), c(".pred_one", ".pred_two", ".pred_three", ".row", "outcome", ".config", ".pred_class")