From e601d1c549ed04f56563d5924a83de20bd28e434 Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 30 May 2024 11:15:19 -0400 Subject: [PATCH] changes for #141 (#148) --- NEWS.md | 2 ++ R/conformal_infer_cv.R | 2 +- tests/testthat/_snaps/conformal-intervals.md | 4 +++ tests/testthat/test-conformal-intervals.R | 35 ++++++++++++++++++++ 4 files changed, 42 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 6d0382bf..b500008e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,8 @@ * A new function `bound_prediction()` is available to constrain the values of a numeric prediction (#142). +* Fixed an error in `int_conformal_cv()` when grouped resampling was used (#141). + # probably 1.0.3 * Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127). diff --git a/R/conformal_infer_cv.R b/R/conformal_infer_cv.R index fa84dcda..31910f66 100644 --- a/R/conformal_infer_cv.R +++ b/R/conformal_infer_cv.R @@ -221,7 +221,7 @@ new_infer_cv <- function(models, resid) { check_resampling <- function(x) { rs <- attr(x, "rset_info") - if (rs$att$class != "vfold_cv") { + if (any(rs$att$class != "vfold_cv") | any(grepl("group_", rs$att$class))) { msg <- paste0( "The data were resampled using ", rs$label, ". This method was developed for V-fold cross-validation. Interval ", diff --git a/tests/testthat/_snaps/conformal-intervals.md b/tests/testthat/_snaps/conformal-intervals.md index 01d606c5..7ce6ee33 100644 --- a/tests/testthat/_snaps/conformal-intervals.md +++ b/tests/testthat/_snaps/conformal-intervals.md @@ -193,3 +193,7 @@ Error in `control_conformal_full()`: ! `method` must be one of "iterative" or "grid", not "rock-paper-scissors". +# group resampling to conformal CV intervals + + The data were resampled using Group 2-fold cross-validation. This method was developed for V-fold cross-validation. Interval coverage is unknown for your resampling method. + diff --git a/tests/testthat/test-conformal-intervals.R b/tests/testthat/test-conformal-intervals.R index b2042af9..071ae1c1 100644 --- a/tests/testthat/test-conformal-intervals.R +++ b/tests/testthat/test-conformal-intervals.R @@ -237,3 +237,38 @@ test_that("conformal control", { expect_snapshot(dput(control_conformal_full(max_iter = 2))) expect_snapshot(error = TRUE, control_conformal_full(method = "rock-paper-scissors")) }) + + +test_that("group resampling to conformal CV intervals", { + skip_if_not_installed("modeldata") + skip_if_not_installed("nnet") + + make_data <- function(n, std_dev = 1 / 5) { + tibble(x = runif(n, min = -1)) %>% + mutate( + y = (x^3) + 2 * exp(-6 * (x - 0.3)^2), + y = y + rnorm(n, sd = std_dev) + ) + } + + n <- 100 + set.seed(8383) + train_data <- make_data(n) %>% + mutate(color = sample(c('red', 'blue'), n(), replace = TRUE)) + + set.seed(484) + nnet_wflow <- + workflow(y ~ x, mlp(hidden_units = 2) %>% set_mode("regression")) + + group_folds <- group_vfold_cv(train_data, group = color) + + ctrl <- control_resamples(save_pred = TRUE, extract = I) + + group_nnet_rs <- + nnet_wflow %>% + fit_resamples(group_folds, control = ctrl) + + expect_snapshot_warning(int_conformal_cv(group_nnet_rs)) + +}) +