Survival analysis updates (#88)
* initial updates

* changes from tidymodels/tune#782

* basic tests for new tune functions
topepo authored Dec 13, 2023
1 parent 2856107 commit a57178b
Showing 12 changed files with 103 additions and 58 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: finetune
Title: Additional Functions for Model Tuning
Version: 1.1.0.9003
Version: 1.1.0.9004
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre"),
comment = c(ORCID = "0000-0003-2402-136X")),
@@ -19,7 +19,7 @@ URL: https://github.com/tidymodels/finetune,
BugReports: https://github.com/tidymodels/finetune/issues
Depends:
R (>= 3.5),
tune (>= 1.1.2.9001)
tune (>= 1.1.2.9004)
Imports:
cli,
dials (>= 0.1.0),
@@ -50,7 +50,7 @@ Suggests:
testthat,
yardstick
Remotes:
tidymodels/tune#767
tidymodels/tune
Config/Needs/website: tidyverse/tidytemplate
Config/testthat/edition: 3
Encoding: UTF-8
2 changes: 2 additions & 0 deletions NEWS.md
@@ -1,5 +1,7 @@
# finetune (development version)

* Updates based on the new version of tune, primarily for survival analysis models.

# finetune 1.1.0

* Various minor changes to keep up with developments in the tune and dplyr packages (#60) (#62) (#67) (#68).
50 changes: 30 additions & 20 deletions R/racing_helpers.R
@@ -387,7 +387,18 @@ harmonize_configs <- function(x, key) {
x
}

restore_tune <- function(x, y) {
# restore_tune() restores certain attributes (especially the class) that are
# lost during racing when rows of the resampling object are filtered.
# `x` has class `tune_results`; `y` has the same structure but different
# attributes.
# About eval_time_target: `x` comes from `tune_grid()`, which has no notion of
# a target evaluation time. https://github.com/tidymodels/tune/pull/782
# defaults `eval_time_target` to NULL for grid tuning, resampling, and last
# fit objects. That is why `eval_time_target` is an argument here: it should
# be non-NULL for the resulting racing object, but the value inherited from
# `x` is NULL unless we set it.

restore_tune <- function(x, y, eval_time_target = NULL) {
# With a smaller number of parameter combinations, the .config values may have
# changed. We'll use the full set of parameters in `x` to adjust what is in
# `y`.
@@ -399,6 +410,7 @@ restore_tune <- function(x, y) {

att <- attributes(x)
att$row.names <- 1:(nrow(x) + nrow(y))
att$eval_time_target <- eval_time_target
att$class <- c("tune_race", "tune_results", class(tibble::tibble()))
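
To make the attribute bookkeeping concrete, here is a minimal, hypothetical sketch of the capture-adjust-reapply pattern used above; the objects and the target time are invented for illustration and are not code from this commit:

library(tibble)

x <- tibble(a = 1:3)                        # stand-in for a tune_grid() result
class(x) <- c("tune_race", "tune_results", class(tibble()))

y <- tibble(a = 4:5)                        # same structure, plain attributes

att <- attributes(x)                        # capture the attributes of `x`
att$row.names <- 1:(nrow(x) + nrow(y))      # cover the combined rows
att$eval_time_target <- 10                  # set explicitly; `x` carries NULL
att$class <- c("tune_race", "tune_results", class(tibble()))

res <- rbind(as.data.frame(x), as.data.frame(y))
attributes(res) <- att                      # `res` is a tune_race object again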


@@ -693,8 +705,23 @@ collect_metrics.tune_race <- function(x, summarize = TRUE, all_configs = FALSE,
#' different resamples is likely to lead to inappropriate results.
#' @export
show_best.tune_race <- function(x, metric = NULL, n = 5, eval_time = NULL, ...) {

if (!is.null(metric)) {
# Which metric was used to judge the race, and which one is being used to rank the candidates now?
metrics <- tune::.get_tune_metrics(x)
opt_metric <- tune::first_metric(metrics)
opt_metric_name <- opt_metric$metric
if (metric[1] != opt_metric_name) {
cli::cli_warn("Metric {.val {opt_metric_name}} was used to evaluate model
candidates in the race but {.val {metric}} has been chosen
to rank the candidates. These results may not agree with the
race.")
}
}

x <- dplyr::select(x, -.order)
final_configs <- subset_finished_race(x)

res <- NextMethod(metric = metric, n = Inf, eval_time = eval_time, ...)
res$.ranked <- 1:nrow(res)
res <- dplyr::inner_join(res, final_configs, by = ".config")
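
A hedged usage sketch of the new guardrail; `race_res` is a hypothetical tune_race object whose race was judged on roc_auc, the first metric in its metric set:

show_best(race_res, metric = "accuracy", n = 5)
#> Warning: Metric "roc_auc" was used to evaluate model candidates in the race
#> but "accuracy" has been chosen to rank the candidates. These results may
#> not agree with the race.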
@@ -720,27 +747,10 @@ subset_finished_race <- function(x) {
# ------------------------------------------------------------------------------
# Log the objective function used for racing

racing_obj_log <- function(x, metrics, control, eval_time = NULL) {
metric_info <- tibble::as_tibble(metrics)
analysis_metric <- metric_info$metric[1]
analysis_max <- metric_info$direction[1] == "maximize"
is_dyn <- metric_info$class[1] == "dynamic_survival_metric"
if (is_dyn) {
metrics_time <- eval_time[1]
} else {
metrics_time <- NULL
}

racing_obj_log <- function(analysis_metric, direction, control, metrics_time = NULL) {
cols <- tune::get_tune_colors()
if (control$verbose_elim) {
msg <-
paste(
"Racing will",
ifelse(analysis_max, "maximize", "minimize"),
"the",
analysis_metric,
"metric"
)
msg <- paste("Racing will", direction, "the", analysis_metric, "metric")

if (!is.null(metrics_time)) {
msg <- paste(msg, "at time", format(metrics_time, digits = 3))
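
The refactor moves the metric bookkeeping out of racing_obj_log() and into its callers, which now pass the resolved pieces directly. A hypothetical call of this internal helper under the new signature (the survival metric and time are invented):

finetune:::racing_obj_log(
  analysis_metric = "brier_survival",
  direction       = "minimize",
  control         = control_race(verbose_elim = TRUE),
  metrics_time    = 10
)
#> Racing will minimize the brier_survival metric at time 10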
23 changes: 14 additions & 9 deletions R/tune_race_anova.R
@@ -214,7 +214,7 @@ tune_race_anova.workflow <-

tune_race_anova_workflow <-
function(object, resamples, param_info = NULL, grid = 10, metrics = NULL,
control = control_race(), eval_time = NULL) {
control = control_race(), eval_time = NULL, call = caller_env()) {
rlang::check_installed("lme4")

tune::initialize_catalog(control = control)
@@ -244,14 +244,19 @@ tune_race_anova_workflow <-
)

param_names <- tune::.get_tune_parameter_names(res)

metrics <- tune::.get_tune_metrics(res)
metrics <- tune::check_metrics_arg(metrics, object, call = call)
opt_metric <- tune::first_metric(metrics)
opt_metric_name <- opt_metric$metric
maximize <- opt_metric$direction == "maximize"

racing_obj_log(res, metrics, control, eval_time)
eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)

analysis_metric <- names(attr(metrics, "metrics"))[1]
metrics_time <- eval_time[1]
racing_obj_log(opt_metric_name, opt_metric$direction, control, opt_metric_time)

filters_results <- test_parameters_gls(res, control$alpha, metrics_time)
filters_results <- test_parameters_gls(res, control$alpha, opt_metric_time)
n_grid <- nrow(filters_results)

log_final <- TRUE
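
In isolation, the metric and evaluation-time resolution above looks like the following sketch; it assumes the development version of tune that exports these helpers, and the metric set and times are hypothetical:

library(tune)
library(yardstick)

metrics   <- metric_set(brier_survival, roc_auc_survival)
eval_time <- c(5, 10)

opt_metric      <- tune::first_metric(metrics)  # racing is judged on the first metric
opt_metric_name <- opt_metric$metric            # "brier_survival"
maximize        <- opt_metric$direction == "maximize"

eval_time       <- tune::check_eval_time_arg(eval_time, metrics)
opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)  # 5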
@@ -267,11 +272,11 @@

if (nrow(new_grid) > 1) {
tmp_resamples <- restore_rset(resamples, rs)
log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
} else {
tmp_resamples <- restore_rset(resamples, rs:B)
if (log_final) {
log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
}
log_final <- FALSE
}
@@ -288,10 +293,10 @@
eval_time = eval_time
)

res <- restore_tune(res, tmp_res)
res <- restore_tune(res, tmp_res, opt_metric_time)

if (nrow(new_grid) > 1) {
filters_results <- test_parameters_gls(res, control$alpha, metrics_time)
filters_results <- test_parameters_gls(res, control$alpha, opt_metric_time)
if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) {
filters_results <- tie_breaker(res, control, eval_time = eval_time)
}
23 changes: 14 additions & 9 deletions R/tune_race_win_loss.R
@@ -211,7 +211,7 @@ tune_race_win_loss.workflow <- function(object, resamples, ..., param_info = NUL

tune_race_win_loss_workflow <-
function(object, resamples, param_info = NULL, grid = 10, metrics = NULL,
control = control_race(), eval_time = NULL) {
control = control_race(), eval_time = NULL, call = caller_env()) {
rlang::check_installed("BradleyTerry2")

B <- nrow(resamples)
@@ -239,12 +239,17 @@ tune_race_win_loss_workflow <-
param_names <- tune::.get_tune_parameter_names(res)

metrics <- tune::.get_tune_metrics(res)
racing_obj_log(res, metrics, control, eval_time)
metrics <- tune::check_metrics_arg(metrics, object, call = call)
opt_metric <- tune::first_metric(metrics)
opt_metric_name <- opt_metric$metric
maximize <- opt_metric$direction == "maximize"

analysis_metric <- names(attr(metrics, "metrics"))[1]
metrics_time <- eval_time[1]
eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)

filters_results <- test_parameters_bt(res, control$alpha, metrics_time)
racing_obj_log(opt_metric_name, opt_metric$direction, control, opt_metric_time)

filters_results <- test_parameters_bt(res, control$alpha, opt_metric_time)
n_grid <- nrow(filters_results)

log_final <- TRUE
@@ -260,11 +265,11 @@

if (nrow(new_grid) > 1) {
tmp_resamples <- restore_rset(resamples, rs)
log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
} else {
tmp_resamples <- restore_rset(resamples, rs:B)
if (log_final) {
log_racing(control, filters_results, res$splits, n_grid, analysis_metric)
log_racing(control, filters_results, res$splits, n_grid, opt_metric_name)
}
log_final <- FALSE
}
@@ -280,10 +285,10 @@
control = grid_control,
eval_time = eval_time
)
res <- restore_tune(res, tmp_res)
res <- restore_tune(res, tmp_res, opt_metric_time)

if (nrow(new_grid) > 1) {
filters_results <- test_parameters_bt(res, control$alpha, metrics_time)
filters_results <- test_parameters_bt(res, control$alpha, opt_metric_time)
if (sum(filters_results$pass) == 2 & num_ties >= control$num_ties) {
filters_results <- tie_breaker(res, control)
}
39 changes: 25 additions & 14 deletions R/tune_sim_anneal.R
@@ -304,7 +304,8 @@ tune_sim_anneal.workflow <-

tune_sim_anneal_workflow <-
function(object, resamples, iter = 10, param_info = NULL, metrics = NULL,
initial = 5, control = control_sim_anneal(), eval_time = NULL) {
initial = 5, control = control_sim_anneal(), eval_time = NULL,
call = caller_env()) {
start_time <- proc.time()[3]
cols <- tune::get_tune_colors()

@@ -314,10 +315,13 @@
tune::check_rset(resamples)
rset_info <- tune::pull_rset_attributes(resamples)

metrics <- tune::check_metrics(metrics, object)
metrics_name <- names(attr(metrics, "metrics"))[1]
metrics_time <- eval_time[1]
maximize <- attr(attr(metrics, "metrics")[[1]], "direction") == "maximize"
metrics <- tune::check_metrics_arg(metrics, object, call = call)
opt_metric <- tune::first_metric(metrics)
opt_metric_name <- opt_metric$metric
maximize <- opt_metric$direction == "maximize"

eval_time <- tune::check_eval_time_arg(eval_time, metrics, call = call)
opt_metric_time <- tune::first_eval_time(metrics, opt_metric_name, eval_time)

if (is.null(param_info)) {
param_info <- extract_parameter_set_dials(object)
@@ -362,6 +366,7 @@ tune_sim_anneal_workflow <-
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
eval_time_target = opt_metric_time,
outcomes = y_names,
rset_info = rset_info,
workflow = object
@@ -388,6 +393,7 @@
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
eval_time_target = opt_metric_time,
outcomes = y_names,
rset_info = rset_info,
workflow = object
@@ -397,16 +403,20 @@
return(out)
})

cols <- tune::get_tune_colors()
if (control$verbose_iter) {
cli::cli_bullets(cols$message$info(paste("Optimizing", metrics_name)))
msg <- paste("Optimizing", opt_metric_name)
if (!is.null(opt_metric_time)) {
msg <- paste(msg, "at evaluation time", format(opt_metric_time, digits = 3))
}
cli::cli_bullets(msg)
}


## -----------------------------------------------------------------------------

result_history <- initialize_history(unsummarized, metrics_time)
result_history <- initialize_history(unsummarized, opt_metric_time)
best_param <-
tune::select_best(unsummarized, metric = metrics_name, eval_time = metrics_time) %>%
tune::select_best(unsummarized, metric = opt_metric_name, eval_time = opt_metric_time) %>%
dplyr::mutate(.parent = NA_character_)
grid_history <- best_param
current_param <- best_param
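
Simulated annealing now seeds the search with the best initial configuration selected at the target evaluation time. A hedged sketch, where the result object, metric name, and time are hypothetical:

best_param <- tune::select_best(
  unsummarized,                  # initial resampling results (assumed object)
  metric    = "brier_survival",
  eval_time = 10
)
best_param <- dplyr::mutate(best_param, .parent = NA_character_)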
@@ -426,7 +436,7 @@
x = result_history,
max_iter = iter,
maximize = maximize,
metric = metrics_name
metric = opt_metric_name
)

for (i in (existing_iter + 1):iter) {
@@ -451,17 +461,17 @@
grid = new_grid %>% dplyr::select(-.config, -.parent),
metrics = metrics,
control = control_init,
eval_time = metrics_time
eval_time = eval_time
) %>%
dplyr::mutate(.iter = i) %>%
update_config(config = paste0("Iter", i), save_pred = control$save_pred)

result_history <-
result_history %>%
update_history(res, i, eval_time = metrics_time) %>%
update_history(res, i, eval_time = opt_metric_time) %>%
sa_decide(
parent = new_grid$.parent,
metric = metrics_name,
metric = opt_metric_name,
maximize = maximize,
coef = control$cooling_coef
)
@@ -508,6 +518,7 @@
parameters = param_info,
metrics = metrics,
eval_time = eval_time,
eval_time_target = opt_metric_time,
outcomes = y_names,
rset_info = rset_info,
workflow = object
@@ -520,7 +531,7 @@
x = result_history,
max_iter = iter,
maximize = maximize,
metric = metrics_name
metric = opt_metric_name
)

if (count_improve >= control$no_improve) {
8 changes: 8 additions & 0 deletions tests/testthat/_snaps/win-loss-overall.md
@@ -11,3 +11,11 @@
i Fold2, Repeat2: 0 eliminated; 5 candidates remain.
i Fold3, Repeat2: 0 eliminated; 5 candidates remain.

# one player is really bad

Code
best_res <- show_best(tuning_results)
Condition
Warning in `show_best()`:
No value of `metric` was given; "roc_auc" will be used.

1 change: 1 addition & 0 deletions tests/testthat/test-anova-overall.R
@@ -16,6 +16,7 @@ test_that("formula interface", {
expect_true(nrow(collect_metrics(res)) < nrow(grid_mod) * 2)
expect_equal(res, .Last.tune.result)
expect_null(.get_tune_eval_times(res))
expect_null(.get_tune_eval_time_target(res))
})
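
Both eval-time accessors are expected to return NULL for non-survival results; a small hedged sketch against the `res` object from the test above:

library(tune)
.get_tune_eval_times(res)        # NULL: no dynamic survival metrics here
.get_tune_eval_time_target(res)  # NULL: no target time was ever set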

# ------------------------------------------------------------------------------
3 changes: 2 additions & 1 deletion tests/testthat/test-sa-decision.R
@@ -21,9 +21,10 @@ test_that("simulated annealing decisions", {
outcomes = cart_outcomes,
metrics = cart_metrics,
eval_time = NULL,
eval_time_target = NULL,
rset_info = cart_rset_info
)
iter_new_hist <- finetune:::update_history(iter_hist, iter_res, iter_val)
iter_new_hist <- finetune:::update_history(iter_hist, iter_res, iter_val, NULL)
iter_new_hist$random[1:nrow(iter_new_hist)] <- cart_history$random[1:nrow(iter_new_hist)]

expect_equal(
1 change: 1 addition & 0 deletions tests/testthat/test-sa-overall.R
@@ -12,6 +12,7 @@ test_that("formula interface", {
expect_true(nrow(collect_metrics(res)) == 6)
expect_equal(res, .Last.tune.result)
expect_null(.get_tune_eval_times(res))
expect_null(.get_tune_eval_time_target(res))
})

# ------------------------------------------------------------------------------
1 change: 1 addition & 0 deletions tests/testthat/test-win-loss-filter.R
@@ -30,6 +30,7 @@ test_that("top-level win/loss filter interfaces", {
expect_true(inherits(wl_mod, "tune_results"))
expect_true(tibble::is_tibble((wl_mod)))
expect_null(.get_tune_eval_times(wl_mod))
expect_null(.get_tune_eval_time_target(wl_mod))

expect_silent({
set.seed(129)