Skip to content

Commit

Permalink
check compatibility of initial and param_info (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Jan 25, 2024
1 parent b18c9c5 commit e65925b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# finetune (development version)

* Improved error message from `tune_sim_anneal()` when values in the supplied `param_info` do not encompass all values evaluated in the `initial` grid. This most often happens when a user mistakenly supplies different parameter sets to the function that generated the initial results and `tune_sim_anneal()`.

* Fixed bug where `tune_sim_anneal()` would fail when supplied parameters needing finalization. The function will now finalize needed parameter ranges internally (#39).

* Fixed bug where packages specified in `control_race(pkgs)` were not actually loaded in `tune_race_anova()` (#74).
Expand Down
19 changes: 19 additions & 0 deletions R/sim_anneal_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,30 @@ random_real_neighbor <- function(current, hist_values, pset, retain = 1,

# Convert a set of unit-encoded parameter values back through
# `dials::encode_unit(direction = "backward")`.
#
# x:    a list-like object (e.g. a data frame) with one named element per
#       tuning parameter.
# pset: a parameter set with `id` and `object` columns; only the rows whose
#       `id` appears in `names(x)` are used.
# ...:  accepted but not used here.
#
# Returns a tibble with the same names as `x`. Rows of `pset` and elements of
# `x` are paired by position, so they are assumed to be in the same order.
encode_set_backwards <- function(x, pset, ...) {
  param_names <- names(x)
  pset <- pset[pset$id %in% param_names, ]

  # Run purely for its side effect: aborts if any parameter's values cannot
  # be decoded with the ranges in `pset`. The returned list is discarded.
  Map(check_backwards_encode, pset$object, x, pset$id)

  decoded <- purrr::map2(
    pset$object,
    x,
    dials::encode_unit,
    direction = "backward"
  )
  names(decoded) <- param_names
  tibble::as_tibble(decoded)
}

# Check that the encoded values for a single parameter lie within [0, 1].
#
# x:     a parameter object (as stored in the `object` column of a parameter
#        set).
# value: the encoded values for that parameter; `NA`s are ignored.
# id:    the parameter's identifier, used only in the error message.
#
# Returns `NULL` invisibly when the check passes or is skipped. Aborts with a
# cli error when any non-missing value falls outside [0, 1].
check_backwards_encode <- function(x, value, id) {
  # Parameters that still have unknowns are not checked.
  if (dials::has_unknowns(x)) {
    return(invisible(NULL))
  }

  compl <- value[!is.na(value)]

  # Both operands are length-one results of `any()`, so the scalar,
  # short-circuiting `||` is the appropriate operator inside `if`.
  if (any(compl < 0) || any(compl > 1)) {
    cli::cli_abort(
      c(
        "!" = "The range for parameter {.val {noquote(id)}} used when \\
               generating initial results isn't compatible with the range \\
               supplied in {.arg param_info}.",
        "i" = "Possible values of parameters in {.arg param_info} should \\
               encompass all values evaluated in the initial grid."
      ),
      # Report the error as coming from the user-facing function rather than
      # from this internal helper.
      call = rlang::call2("tune_sim_anneal()")
    )
  }

  invisible(NULL)
}

sample_by_distance <- function(candidates, existing, retain, pset) {
if (nrow(existing) > 0) {
existing <- tune::encode_set(existing, pset, as_matrix = TRUE)
Expand Down
15 changes: 15 additions & 0 deletions tests/testthat/_snaps/sa-overall.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,18 @@
9 - discard suboptimal roc_auc=0.84525 (+/-0.007793)
10 ( ) accept suboptimal roc_auc=0.84383 (+/-0.00773)

# incompatible parameter objects

Code
res <- tune_sim_anneal(car_wflow, param_info = parameter_set_with_smaller_range,
resamples = car_folds, initial = tune_res_with_bigger_range, iter = 2)
Message
Optimizing rmse
Condition
Error in `tune_sim_anneal()`:
! The range for parameter mtry used when generating initial results isn't compatible with the range supplied in `param_info`.
i Possible values of parameters in `param_info` should encompass all values evaluated in the initial grid.
Message
x Optimization stopped prematurely; returning current results.

50 changes: 50 additions & 0 deletions tests/testthat/test-sa-overall.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,56 @@ test_that("unfinalized parameters", {
})
})

# Snapshot test: tune_sim_anneal() should raise an error when `param_info`
# has a narrower range than the one used to produce the `initial` results.
test_that("incompatible parameter objects", {
skip_on_cran()

skip_if_not_installed("ranger")
skip_if_not_installed("modeldata")
skip_if_not_installed("rsample")

rf_spec <- parsnip::rand_forest(mode = "regression", mtry = tune::tune())

# Initial grid drawn from the wider mtry range (1 to 16).
set.seed(1)
grid_with_bigger_range <-
dials::grid_latin_hypercube(dials::mtry(range = c(1, 16)))

# NOTE(review): `car_prices` is not defined in this block; presumably it is
# made available by the test setup (modeldata is required above) -- confirm.
set.seed(1)
car_folds <- rsample::vfold_cv(car_prices, v = 2)

car_wflow <- workflows::workflow() %>%
workflows::add_formula(Price ~ .) %>%
workflows::add_model(rf_spec)

# Initial tuning results generated with the wider grid.
set.seed(1)
tune_res_with_bigger_range <- tune::tune_grid(
car_wflow,
resamples = car_folds,
grid = grid_with_bigger_range
)

# Parameter set with a narrower mtry range (1 to 5) than the initial grid.
set.seed(1)
parameter_set_with_smaller_range <-
dials::parameters(dials::mtry(range = c(1, 5)))

# Snapshot transform: blank out any output line containing "Initial best"
# (presumably so the snapshot does not depend on the reported metric value).
scrub_best <- function(lines) {
has_best <- grepl("Initial best", lines)
lines[has_best] <- ""
lines
}

# The variable names below appear verbatim in the recorded snapshot, so
# renaming them would invalidate it.
set.seed(1)
expect_snapshot(error = TRUE, transform = scrub_best, {
res <-
tune_sim_anneal(
car_wflow,
param_info = parameter_set_with_smaller_range,
resamples = car_folds,
initial = tune_res_with_bigger_range,
iter = 2
)
})
})

test_that("set event-level", {
# See issue 40
skip_if_not_installed("rpart")
Expand Down

0 comments on commit e65925b

Please sign in to comment.