diff --git a/DESCRIPTION b/DESCRIPTION
index 6a66fc1..ccb6dd2 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: finetune
 Title: Additional Functions for Model Tuning
-Version: 1.1.0.9003
+Version: 1.1.0.9004
 Authors@R: c(
     person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre"),
            comment = c(ORCID = "0000-0003-2402-136X")),
@@ -19,7 +19,7 @@ URL: https://github.com/tidymodels/finetune,
 BugReports: https://github.com/tidymodels/finetune/issues
 Depends:
     R (>= 3.5),
-    tune (>= 1.1.2.9001)
+    tune (>= 1.1.2.9004)
 Imports:
     cli,
     dials (>= 0.1.0),
@@ -50,7 +50,7 @@ Suggests:
     testthat,
     yardstick
 Remotes:
-    tidymodels/tune#767
+    tidymodels/tune
 Config/Needs/website: tidyverse/tidytemplate
 Config/testthat/edition: 3
 Encoding: UTF-8
diff --git a/NEWS.md b/NEWS.md
index 3c04be3..b00e99b 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,7 @@
 # finetune (development version)
 
+* Updates based on the new version of tune, primarily for survival analysis models.
+
 # finetune 1.1.0
 
 * Various minor changes to keep up with developments in the tune and dplyr packages (#60) (#62) (#67) (#68).
diff --git a/R/racing_helpers.R b/R/racing_helpers.R
index b39ff0c..48e8092 100644
--- a/R/racing_helpers.R
+++ b/R/racing_helpers.R
@@ -387,7 +387,18 @@ harmonize_configs <- function(x, key) {
   x
 }
 
-restore_tune <- function(x, y) {
+# restore_tune() restores certain attributes (esp class) that are lost during
+# racing when rows of the resampling object are filtered.
+# `x` has class `tune_results`. `y` has the same structure but different
+# attributes.
+# About eval_time_target: `x` is from `tune_grid()`, which has no notion of a
+# target evaluation time. https://github.com/tidymodels/tune/pull/782 defaults
+# `eval_time_target` to NULL for grid tuning, resampling, and last fit objects.
+# That's why `eval_time_target` is an argument to this function. It should be
+# non-null for the resulting racing object but the value inherited from `x` is
+# NULL unless we set it.
+
+restore_tune <- function(x, y, eval_time_target = NULL) {
   # With a smaller number of parameter combinations, the .config values may have
   # changed. We'll use the full set of parameters in `x` to adjust what is in
   # `y`.
@@ -399,6 +410,7 @@
   att <- attributes(x)
   att$row.names <- 1:(nrow(x) + nrow(y))
+  att$eval_time_target <- eval_time_target
   att$class <- c("tune_race", "tune_results", class(tibble::tibble()))
 
@@ -693,8 +705,23 @@ collect_metrics.tune_race <- function(x, summarize = TRUE, all_configs = FALSE,
 #' different resamples is likely to lead to inappropriate results.
 #' @export
 show_best.tune_race <- function(x, metric = NULL, n = 5, eval_time = NULL, ...) {
+
+  if (!is.null(metric)) {
+    # What was used to judge the race and how are they being sorted now?
+    metrics <- tune::.get_tune_metrics(x)
+    opt_metric <- tune::first_metric(metrics)
+    opt_metric_name <- opt_metric$metric
+    if (metric[1] != opt_metric_name) {
+      cli::cli_warn("Metric {.val {opt_metric_name}} was used to evaluate model
+                    candidates in the race but {.val {metric}} has been chosen
+                    to rank the candidates. These results may not agree with the
+                    race.")
+    }
+  }
+
   x <- dplyr::select(x, -.order)
   final_configs <- subset_finished_race(x)
+
   res <- NextMethod(metric = metric, n = Inf, eval_time = eval_time, ...)
   res$.ranked <- 1:nrow(res)
   res <- dplyr::inner_join(res, final_configs, by = ".config")
@@ -720,27 +747,10 @@ subset_finished_race <- function(x) {
 # ------------------------------------------------------------------------------
 # Log the objective function used for racing
 
-racing_obj_log <- function(x, metrics, control, eval_time = NULL) {
-  metric_info <- tibble::as_tibble(metrics)
-  analysis_metric <- metric_info$metric[1]
-  analysis_max <- metric_info$direction[1] == "maximize"
-  is_dyn <- metric_info$class[1] == "dynamic_survival_metric"
-  if (is_dyn) {
-    metrics_time <- eval_time[1]
-  } else {
-    metrics_time <- NULL
-  }
-
+racing_obj_log <- function(analysis_metric, direction, control, metrics_time = NULL) {
   cols <- tune::get_tune_colors()
   if (control$verbose_elim) {
-    msg <-
-      paste(
-        "Racing will",
-        ifelse(analysis_max, "maximize", "minimize"),
-        "the",
-        analysis_metric,
-        "metric"
-      )
+    msg <- paste("Racing will", direction, "the", analysis_metric, "metric")
 
     if (!is.null(metrics_time)) {
       msg <- paste(msg, "at time", format(metrics_time, digits = 3))
diff --git a/R/tune_race_anova.R b/R/tune_race_anova.R
index 66eaf19..2c54236 100644
--- a/R/tune_race_anova.R
+++ b/R/tune_race_anova.R
@@ -214,7 +214,7 @@ tune_race_anova.workflow <-
 tune_race_anova_workflow <-
   function(object, resamples, param_info = NULL, grid = 10, metrics = NULL,
-           control = control_race(), eval_time = NULL) {
+           control = control_race(), eval_time = NULL, call = caller_env()) {
     rlang::check_installed("lme4")
 
     tune::initialize_catalog(control = control)
@@ -244,14 +244,19 @@
     )
 
     param_names <- tune::.get_tune_parameter_names(res)
+    metrics <- tune::.get_tune_metrics(res)
+    metrics <- tune::check_metrics_arg(metrics, object, call = call)
+    opt_metric <- tune::first_metric(metrics)
+    opt_metric_name <- opt_metric$metric
+    maximize <- opt_metric$direction == "maximize"
 
-    racing_obj_log(res, metrics, control, eval_time)
+    eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
+    opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)
 
-    analysis_metric <- names(attr(metrics, "metrics"))[1]
-    metrics_time <- eval_time[1]
+    racing_obj_log(opt_metric_name, opt_metric$direction, control, opt_metric_time)
 
-    filters_results <- test_parameters_gls(res, control$alpha, metrics_time)
+    filters_results <- test_parameters_gls(res, control$alpha, opt_metric_time)
     n_grid <- nrow(filters_results)
 
     log_final <- TRUE
@@ -267,11 +272,11 @@
       if (nrow(new_grid) > 1) {
         tmp_resamples <- restore_rset(resamples, rs)
-        log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
+        log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
       } else {
         tmp_resamples <- restore_rset(resamples, rs:B)
         if (log_final) {
-          log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
+          log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
         }
         log_final <- FALSE
       }
@@ -288,10 +293,10 @@
           eval_time = eval_time
         )
 
-      res <- restore_tune(res, tmp_res)
+      res <- restore_tune(res, tmp_res, opt_metric_time)
 
       if (nrow(new_grid) > 1) {
-        filters_results <- test_parameters_gls(res, control$alpha, metrics_time)
+        filters_results <- test_parameters_gls(res, control$alpha, opt_metric_time)
         if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) {
           filters_results <- tie_breaker(res, control, eval_time = eval_time)
         }
diff --git a/R/tune_race_win_loss.R b/R/tune_race_win_loss.R
index cbd6d55..8240e51 100644
--- a/R/tune_race_win_loss.R
+++ b/R/tune_race_win_loss.R
@@ -211,7 +211,7 @@ tune_race_win_loss.workflow <- function(object, resamples, ..., param_info = NUL
 tune_race_win_loss_workflow <-
   function(object, resamples, param_info = NULL, grid = 10, metrics = NULL,
-           control = control_race(), eval_time = NULL) {
+           control = control_race(), eval_time = NULL, call = caller_env()) {
     rlang::check_installed("BradleyTerry2")
 
     B <- nrow(resamples)
@@ -239,12 +239,17 @@
     param_names <- tune::.get_tune_parameter_names(res)
     metrics <- tune::.get_tune_metrics(res)
 
-    racing_obj_log(res, metrics, control, eval_time)
+    metrics <- tune::check_metrics_arg(metrics, object, call = call)
+    opt_metric <- tune::first_metric(metrics)
+    opt_metric_name <- opt_metric$metric
+    maximize <- opt_metric$direction == "maximize"
 
-    analysis_metric <- names(attr(metrics, "metrics"))[1]
-    metrics_time <- eval_time[1]
+    eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
+    opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)
 
-    filters_results <- test_parameters_bt(res, control$alpha, metrics_time)
+    racing_obj_log(opt_metric_name, opt_metric$direction, control, opt_metric_time)
+
+    filters_results <- test_parameters_bt(res, control$alpha, opt_metric_time)
     n_grid <- nrow(filters_results)
 
     log_final <- TRUE
@@ -260,11 +265,11 @@
      if (nrow(new_grid) > 1) {
        tmp_resamples <- restore_rset(resamples, rs)
-        log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
+        log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
      } else {
        tmp_resamples <- restore_rset(resamples, rs:B)
        if (log_final) {
-          log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
+          log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
        }
        log_final <- FALSE
      }
@@ -280,10 +285,10 @@
           control = grid_control,
           eval_time = eval_time
         )
-      res <- restore_tune(res, tmp_res)
+      res <- restore_tune(res, tmp_res, opt_metric_time)
 
       if (nrow(new_grid) > 1) {
-        filters_results <- test_parameters_bt(res, control$alpha, metrics_time)
+        filters_results <- test_parameters_bt(res, control$alpha, opt_metric_time)
         if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) {
           filters_results <- tie_breaker(res, control)
         }
diff --git a/R/tune_sim_anneal.R b/R/tune_sim_anneal.R
index f160fa3..9f6439a 100644
--- a/R/tune_sim_anneal.R
+++ b/R/tune_sim_anneal.R
@@ -304,7 +304,8 @@ tune_sim_anneal.workflow <-
 tune_sim_anneal_workflow <-
   function(object, resamples, iter = 10, param_info = NULL, metrics = NULL,
-           initial = 5, control = control_sim_anneal(), eval_time = NULL) {
+           initial = 5, control = control_sim_anneal(), eval_time = NULL,
+           call = caller_env()) {
     start_time <- proc.time()[3]
 
     cols <- tune::get_tune_colors()
@@ -314,10 +315,13 @@
     tune::check_rset(resamples)
     rset_info <- tune::pull_rset_attributes(resamples)
 
-    metrics <- tune::check_metrics(metrics, object)
-    metrics_name <- names(attr(metrics, "metrics"))[1]
-    metrics_time <- eval_time[1]
-    maximize <- attr(attr(metrics, "metrics")[[1]], "direction") == "maximize"
+    metrics <- tune::check_metrics_arg(metrics, object, call = call)
+    opt_metric <- tune::first_metric(metrics)
+    opt_metric_name <- opt_metric$metric
+    maximize <- opt_metric$direction == "maximize"
+
+    eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
+    opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)
 
     if (is.null(param_info)) {
       param_info <- extract_parameter_set_dials(object)
@@ -362,6 +366,7 @@
         parameters = param_info,
         metrics = metrics,
         eval_time = eval_time,
+        eval_time_target = opt_metric_time,
         outcomes = y_names,
         rset_info = rset_info,
         workflow = object
@@ -388,6 +393,7 @@
         parameters = param_info,
         metrics = metrics,
         eval_time = eval_time,
+        eval_time_target = opt_metric_time,
         outcomes = y_names,
         rset_info = rset_info,
         workflow = object
@@ -397,16 +403,20 @@
         return(out)
       })
 
-    cols <- tune::get_tune_colors()
     if (control$verbose_iter) {
-      cli::cli_bullets(cols$message$info(paste("Optimizing", metrics_name)))
+      msg <- paste("Optimizing", opt_metric_name)
+      if (!is.null(opt_metric_time)) {
+        msg <- paste(msg, "at evaluation time", format(opt_metric_time, digits = 3))
+      }
+      cli::cli_bullets(msg)
     }
+
     ## -----------------------------------------------------------------------------
 
-    result_history <- initialize_history(unsummarized, metrics_time)
+    result_history <- initialize_history(unsummarized, opt_metric_time)
 
     best_param <-
-      tune::select_best(unsummarized, metric = metrics_name, eval_time = metrics_time) %>%
+      tune::select_best(unsummarized, metric = opt_metric_name, eval_time = opt_metric_time) %>%
       dplyr::mutate(.parent = NA_character_)
     grid_history <- best_param
     current_param <- best_param
@@ -426,7 +436,7 @@
         x = result_history,
         max_iter = iter,
         maximize = maximize,
-        metric = metrics_name
+        metric = opt_metric_name
       )
 
     for (i in (existing_iter + 1):iter) {
@@ -451,17 +461,17 @@
           grid = new_grid %>% dplyr::select(-.config, -.parent),
           metrics = metrics,
           control = control_init,
-          eval_time = metrics_time
+          eval_time = eval_time
         ) %>%
         dplyr::mutate(.iter = i) %>%
         update_config(config = paste0("Iter", i), save_pred = control$save_pred)
 
       result_history <-
        result_history %>%
-        update_history(res, i, eval_time = metrics_time) %>%
+        update_history(res, i, eval_time = opt_metric_time) %>%
         sa_decide(
           parent = new_grid$.parent,
-          metric = metrics_name,
+          metric = opt_metric_name,
           maximize = maximize,
           coef = control$cooling_coef
         )
@@ -508,6 +518,7 @@
         parameters = param_info,
         metrics = metrics,
         eval_time = eval_time,
+        eval_time_target = opt_metric_time,
         outcomes = y_names,
         rset_info = rset_info,
         workflow = object
@@ -520,7 +531,7 @@
         x = result_history,
         max_iter = iter,
         maximize = maximize,
-        metric = metrics_name
+        metric = opt_metric_name
      )
 
     if (count_improve >= control$no_improve) {
diff --git a/tests/testthat/_snaps/win-loss-overall.md b/tests/testthat/_snaps/win-loss-overall.md
index 877fab3..9fea7f1 100644
--- a/tests/testthat/_snaps/win-loss-overall.md
+++ b/tests/testthat/_snaps/win-loss-overall.md
@@ -11,3 +11,11 @@
       i Fold2, Repeat2: 0 eliminated; 5 candidates remain.
       i Fold3, Repeat2: 0 eliminated; 5 candidates remain.
 
+# one player is really bad
+
+    Code
+      best_res <- show_best(tuning_results)
+    Condition
+      Warning in `show_best()`:
+      No value of `metric` was given; "roc_auc" will be used.
+
diff --git a/tests/testthat/test-anova-overall.R b/tests/testthat/test-anova-overall.R
index 063e00f..0f2ecf6 100644
--- a/tests/testthat/test-anova-overall.R
+++ b/tests/testthat/test-anova-overall.R
@@ -16,6 +16,7 @@ test_that("formula interface", {
   expect_true(nrow(collect_metrics(res)) < nrow(grid_mod) * 2)
   expect_equal(res, .Last.tune.result)
   expect_null(.get_tune_eval_times(res))
+  expect_null(.get_tune_eval_time_target(res))
 })
 
 # ------------------------------------------------------------------------------
diff --git a/tests/testthat/test-sa-decision.R b/tests/testthat/test-sa-decision.R
index 8106e5b..26c50c5 100644
--- a/tests/testthat/test-sa-decision.R
+++ b/tests/testthat/test-sa-decision.R
@@ -21,9 +21,10 @@ test_that("simulated annealing decisions", {
     outcomes = cart_outcomes,
     metrics = cart_metrics,
     eval_time = NULL,
+    eval_time_target = NULL,
     rset_info = cart_rset_info
   )
-  iter_new_hist <- finetune:::update_history(iter_hist, iter_res, iter_val)
+  iter_new_hist <- finetune:::update_history(iter_hist, iter_res, iter_val, NULL)
   iter_new_hist$random[1:nrow(iter_new_hist)] <- cart_history$random[1:nrow(iter_new_hist)]
 
   expect_equal(
diff --git a/tests/testthat/test-sa-overall.R b/tests/testthat/test-sa-overall.R
index 4e03cd6..d666677 100644
--- a/tests/testthat/test-sa-overall.R
+++ b/tests/testthat/test-sa-overall.R
@@ -12,6 +12,7 @@ test_that("formula interface", {
   expect_true(nrow(collect_metrics(res)) == 6)
   expect_equal(res, .Last.tune.result)
   expect_null(.get_tune_eval_times(res))
+  expect_null(.get_tune_eval_time_target(res))
 })
 
 # ------------------------------------------------------------------------------
diff --git a/tests/testthat/test-win-loss-filter.R b/tests/testthat/test-win-loss-filter.R
index 1746c11..bc0195a 100644
--- a/tests/testthat/test-win-loss-filter.R
+++ b/tests/testthat/test-win-loss-filter.R
@@ -30,6 +30,7 @@ test_that("top-level win/loss filter interfaces", {
   expect_true(inherits(wl_mod, "tune_results"))
   expect_true(tibble::is_tibble((wl_mod)))
   expect_null(.get_tune_eval_times(wl_mod))
+  expect_null(.get_tune_eval_time_target(wl_mod))
 
   expect_silent({
     set.seed(129)
diff --git a/tests/testthat/test-win-loss-overall.R b/tests/testthat/test-win-loss-overall.R
index 0cb70cc..6efda57 100644
--- a/tests/testthat/test-win-loss-overall.R
+++ b/tests/testthat/test-win-loss-overall.R
@@ -88,7 +88,7 @@ test_that("one player is really bad", {
     control = ctrl
   )
 
-  # TODO Needs to be fixed in tune package
-  expect_true(nrow(show_best(tuning_results)) > 0)
+  expect_snapshot(best_res <- show_best(tuning_results))
+  expect_true(nrow(best_res) == 1)
 })
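For reviewers who have not yet used the new tune helpers this patch leans on, here is a minimal sketch of the metric/eval-time resolution pattern that replaces the old `names(attr(metrics, "metrics"))[1]` and `eval_time[1]` lookups in the racing and simulated annealing functions. It assumes the development version of tune pinned in DESCRIPTION above; the metric set and the evaluation time of 10 are purely illustrative.

```r
library(tune)       # development version, per the DESCRIPTION/Remotes pin above
library(yardstick)

# A metric set whose first metric is a dynamic survival metric (illustrative).
metrics <- metric_set(brier_survival, roc_auc_survival)

# The first metric in the set is the one that judges the race / optimization.
opt_metric <- tune::first_metric(metrics)
opt_metric_name <- opt_metric$metric                # first metric in the set
maximize <- opt_metric$direction == "maximize"      # FALSE for a Brier score

# check_eval_time_arg() validates `eval_time` against the metric set and
# first_eval_time() picks the single target time for that metric.
eval_time <- tune::check_eval_time_arg(10, metrics)
opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)
```

`opt_metric_time` is the value that `restore_tune()` now stores as the `eval_time_target` attribute, which is why the new tests can check it via `.get_tune_eval_time_target()`.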
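Separately, a small end-to-end sketch of the new `show_best()` guardrail added in `R/racing_helpers.R`. The data set, model, and grid size are placeholders chosen only to produce a quick classification race (lme4 must be installed for the ANOVA race); the point is the warning at the end.

```r
library(tidymodels)
library(finetune)

data(two_class_dat, package = "modeldata")

set.seed(1)
folds <- vfold_cv(two_class_dat, v = 5)

tree_spec <- decision_tree(cost_complexity = tune(), min_n = tune()) %>%
  set_mode("classification")

set.seed(2)
race_res <- tune_race_anova(
  tree_spec,
  Class ~ .,
  resamples = folds,
  grid = 10,
  metrics = metric_set(roc_auc, accuracy)
)

# roc_auc (the first metric in the set) was used to eliminate candidates during
# the race; ranking by accuracy instead now triggers the new warning that these
# results may not agree with the race.
show_best(race_res, metric = "accuracy", n = 3)
```

Calling `show_best(race_res)` with no `metric` at all falls through to tune's existing behavior, which is what the new snapshot in `_snaps/win-loss-overall.md` captures: a warning that "roc_auc" will be used.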