Skip to content

Commit

Permalink
changes for #141 (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
topepo authored May 30, 2024
1 parent 5aa7ecd commit e601d1c
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 1 deletion.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

* A new function `bound_prediction()` is available to constrain the values of a numeric prediction (#142).

* Fixed an error in `int_conformal_cv()` when grouped resampling was used (#141).

# probably 1.0.3

* Fixed a bug where the grouping for calibration methods was sensitive to the type of the grouping variables (#127).
Expand Down
2 changes: 1 addition & 1 deletion R/conformal_infer_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ new_infer_cv <- function(models, resid) {

check_resampling <- function(x) {
rs <- attr(x, "rset_info")
if (rs$att$class != "vfold_cv") {
if (any(rs$att$class != "vfold_cv") | any(grepl("group_", rs$att$class))) {
msg <- paste0(
"The data were resampled using ", rs$label,
". This method was developed for V-fold cross-validation. Interval ",
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/_snaps/conformal-intervals.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,7 @@
Error in `control_conformal_full()`:
! `method` must be one of "iterative" or "grid", not "rock-paper-scissors".

# group resampling to conformal CV intervals

The data were resampled using Group 2-fold cross-validation. This method was developed for V-fold cross-validation. Interval coverage is unknown for your resampling method.

35 changes: 35 additions & 0 deletions tests/testthat/test-conformal-intervals.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,38 @@ test_that("conformal control", {
expect_snapshot(dput(control_conformal_full(max_iter = 2)))
expect_snapshot(error = TRUE, control_conformal_full(method = "rock-paper-scissors"))
})


test_that("group resampling to conformal CV intervals", {
skip_if_not_installed("modeldata")
skip_if_not_installed("nnet")

make_data <- function(n, std_dev = 1 / 5) {
tibble(x = runif(n, min = -1)) %>%
mutate(
y = (x^3) + 2 * exp(-6 * (x - 0.3)^2),
y = y + rnorm(n, sd = std_dev)
)
}

n <- 100
set.seed(8383)
train_data <- make_data(n) %>%
mutate(color = sample(c('red', 'blue'), n(), replace = TRUE))

set.seed(484)
nnet_wflow <-
workflow(y ~ x, mlp(hidden_units = 2) %>% set_mode("regression"))

group_folds <- group_vfold_cv(train_data, group = color)

ctrl <- control_resamples(save_pred = TRUE, extract = I)

group_nnet_rs <-
nnet_wflow %>%
fit_resamples(group_folds, control = ctrl)

expect_snapshot_warning(int_conformal_cv(group_nnet_rs))

})

0 comments on commit e601d1c

Please sign in to comment.