From dc903671955723ffd47ed99a5d04912374801c8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Sun, 21 Jan 2024 13:20:09 -0500 Subject: [PATCH 1/6] Create standalone-input-names.R --- R/standalone-input-names.R | 64 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 R/standalone-input-names.R diff --git a/R/standalone-input-names.R b/R/standalone-input-names.R new file mode 100644 index 00000000..b19488cd --- /dev/null +++ b/R/standalone-input-names.R @@ -0,0 +1,64 @@ +# --- +# repo: tidymodels/workflows +# file: standalone-input-names.R +# last-updated: 2024-01-21 +# license: https://unlicense.org +# --- + +# secret gist at: https://gist.github.com/topepo/17d51cafcd0ac8dff0552198d6aeadbf + +# This file provides a portable set of helper functions for determining the +# names of the predictor columns used as inputs into a workflow. + +# ## Changelog +# 2024-01-21 +# * First version + +# ------------------------------------------------------------------------------ + +check_workflow_fit <- function(x) { + if (!x$trained) { + stop("The workflow should be trainined.") + } + invisible(NULL) +} + +check_recipe_fit <- function(x) { + is_trained <- vapply(x$steps, function(x) x$trained, logical(1)) + if (!all(is_trained)) { + stop("All recipe steps should be trainined.") + } + invisible(NULL) +} + +blueprint_ptype <- function(x) { + names(x$pre$mold$blueprint$ptypes$predictors) +} + +.get_input_predictors_workflow <- function(x, ...) { + check_workflow_fit(x) + # We can get the columns that are inputs to the recipe but some of these may + # not be predictors. We'll interrogate the recipe and pull out the current + # predictor names from the original input + if ("recipe" %in% names(x$pre$actions)) { + mold <- x$pre$mold + rec <- mold$blueprint$recipe + res <- .get_input_predictors_recipe(rec) + } else { + res <- blueprint_ptype(x) + } + sort(unique(res)) +} + +is_predictor_role <- function(x) { + vapply(x$role, function(x) any(x == "predictor"), logical(1)) +} + +.get_input_predictors_recipe <- function(x, ...) { + check_recipe_fit(x) + var_info <- x$last_term_info + + keep_rows <- var_info$source == "original" & is_predictor_role(var_info) + var_info <- var_info[keep_rows,] + var_info$variable +} From 07f30f816a37676b66b6de761a54d7c3c1587058 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Sun, 21 Jan 2024 13:37:57 -0500 Subject: [PATCH 2/6] unit tests --- NEWS.md | 3 + tests/testthat/_snaps/input-names.md | 32 ++++++++++ tests/testthat/test-input-names.R | 88 ++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 tests/testthat/_snaps/input-names.md create mode 100644 tests/testthat/test-input-names.R diff --git a/NEWS.md b/NEWS.md index f5154a03..119647fe 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # workflows (development version) +* Added a standalone file `standalone-input-names.R` with APIs for returning the +names of the predictors in the original data given to `fit()`. + * Each of the `pull_*()` functions soft-deprecated in workflows v0.2.3 now warn on every usage. * `add_recipe()` will now error informatively when supplied a trained recipe (#179). diff --git a/tests/testthat/_snaps/input-names.md b/tests/testthat/_snaps/input-names.md new file mode 100644 index 00000000..b087f326 --- /dev/null +++ b/tests/testthat/_snaps/input-names.md @@ -0,0 +1,32 @@ +# get recipe input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `check_workflow_fit()`: + ! The workflow should be trainined. + +--- + + Code + workflows:::.get_input_predictors_recipe(rec_with_id) + Condition + Error in `check_recipe_fit()`: + ! All recipe steps should be trainined. + +# get formula input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `check_workflow_fit()`: + ! The workflow should be trainined. + +# get predictor input column names + + Code + workflows:::.get_input_predictors_workflow(workflow) + Condition + Error in `check_workflow_fit()`: + ! The workflow should be trainined. + diff --git a/tests/testthat/test-input-names.R b/tests/testthat/test-input-names.R new file mode 100644 index 00000000..2fcc803f --- /dev/null +++ b/tests/testthat/test-input-names.R @@ -0,0 +1,88 @@ +test_that("get recipe input column names", { + skip_if_not_installed("modeldata") + skip_if_not_installed("recipes") + + library(recipes) + + data(cells, package = "modeldata") + + cells <- cells[, 1:10] + pred_names <- sort(names(cells)[3:10]) + + rec_with_id <- + recipes::recipe(class ~ ., cells) %>% + update_role(case, new_role = "destination") %>% + step_rm(angle_ch_1) %>% + step_pca(all_predictors()) + + workflow <- workflow() + workflow <- add_recipe(workflow, rec_with_id) + workflow <- add_model(workflow, parsnip::logistic_reg()) + workflow_fit <- fit(workflow, cells) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + expect_snapshot( + workflows:::.get_input_predictors_recipe(rec_with_id), + error = TRUE + ) + +}) + +test_that("get formula input column names", { + skip_if_not_installed("modeldata") + + data(Chicago, package = "modeldata") + + Chicago <- Chicago[, c("ridership", "date", "Austin")] + pred_names <- sort(c("date", "Austin")) + + workflow <- workflow() + workflow <- add_formula(workflow, ridership ~ .) + workflow <- add_model(workflow, parsnip::linear_reg()) + workflow_fit <- fit(workflow, Chicago) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + +}) + + +test_that("get predictor input column names", { + skip_if_not_installed("modeldata") + + data(Chicago, package = "modeldata") + + Chicago <- Chicago[, c("ridership", "date", "Austin")] + pred_names <- sort(c("date", "Austin")) + + workflow <- workflow() + workflow <- + add_variables(workflow, + outcomes = c(ridership), + predictors = c(tidyselect::everything())) + workflow <- add_model(workflow, parsnip::linear_reg()) + workflow_fit <- fit(workflow, Chicago) + + expect_snapshot( + workflows:::.get_input_predictors_workflow(workflow), + error = TRUE + ) + expect_equal( + workflows:::.get_input_predictors_workflow(workflow_fit), + pred_names + ) + +}) From c93a3beed1d6d604b9fce33021b04e6385dd42ec Mon Sep 17 00:00:00 2001 From: Max Kuhn Date: Thu, 25 Jan 2024 19:01:19 -0500 Subject: [PATCH 3/6] Apply suggestions from code review Co-authored-by: Simon P. Couch --- R/standalone-input-names.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/standalone-input-names.R b/R/standalone-input-names.R index b19488cd..00ce092b 100644 --- a/R/standalone-input-names.R +++ b/R/standalone-input-names.R @@ -18,7 +18,7 @@ check_workflow_fit <- function(x) { if (!x$trained) { - stop("The workflow should be trainined.") + stop("The workflow should be trained.") } invisible(NULL) } @@ -26,7 +26,7 @@ check_workflow_fit <- function(x) { check_recipe_fit <- function(x) { is_trained <- vapply(x$steps, function(x) x$trained, logical(1)) if (!all(is_trained)) { - stop("All recipe steps should be trainined.") + stop("All recipe steps should be trained.") } invisible(NULL) } From 34211475fb0c9c92d5ac475ca092c93cbbf0055c Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 29 Jan 2024 13:29:07 -0500 Subject: [PATCH 4/6] changes based on reviewer feedback --- R/standalone-input-names.R | 77 +++++++++++++++++----------- tests/testthat/_snaps/input-names.md | 16 +++--- 2 files changed, 56 insertions(+), 37 deletions(-) diff --git a/R/standalone-input-names.R b/R/standalone-input-names.R index 00ce092b..437f12e5 100644 --- a/R/standalone-input-names.R +++ b/R/standalone-input-names.R @@ -1,42 +1,32 @@ # --- # repo: tidymodels/workflows # file: standalone-input-names.R -# last-updated: 2024-01-21 +# last-updated: 2024-01-291 # license: https://unlicense.org +# requires: cli, rlang # --- -# secret gist at: https://gist.github.com/topepo/17d51cafcd0ac8dff0552198d6aeadbf - # This file provides a portable set of helper functions for determining the # names of the predictor columns used as inputs into a workflow. # ## Changelog # 2024-01-21 # * First version +# 2024-01-29 +# * Changes after PR review -# ------------------------------------------------------------------------------ - -check_workflow_fit <- function(x) { - if (!x$trained) { - stop("The workflow should be trained.") - } - invisible(NULL) -} +# nocov start -check_recipe_fit <- function(x) { - is_trained <- vapply(x$steps, function(x) x$trained, logical(1)) - if (!all(is_trained)) { - stop("All recipe steps should be trained.") - } - invisible(NULL) -} +# ------------------------------------------------------------------------------ +# Primary functions -blueprint_ptype <- function(x) { - names(x$pre$mold$blueprint$ptypes$predictors) -} +# @param x A _fitted_ workflow or recipe. +# @param call An environment indicating where the top-level function was invoked +# to print out better errors. +# @return A character vector of sorted columns names. -.get_input_predictors_workflow <- function(x, ...) { - check_workflow_fit(x) +.get_input_predictors_workflow <- function(x, ..., call = rlang::current_env()) { + check_workflow_fit(x, call = call) # We can get the columns that are inputs to the recipe but some of these may # not be predictors. We'll interrogate the recipe and pull out the current # predictor names from the original input @@ -50,15 +40,44 @@ blueprint_ptype <- function(x) { sort(unique(res)) } -is_predictor_role <- function(x) { - vapply(x$role, function(x) any(x == "predictor"), logical(1)) -} - -.get_input_predictors_recipe <- function(x, ...) { - check_recipe_fit(x) +.get_input_predictors_recipe <- function(x, ..., call = rlang::current_env()) { + check_recipe_fit(x, call = call) var_info <- x$last_term_info keep_rows <- var_info$source == "original" & is_predictor_role(var_info) var_info <- var_info[keep_rows,] var_info$variable } + +.get_input_outcome_workflow <- function(x) { + check_workflow_fit(x) + names(x$pre$mold$blueprint$ptypes$outcomes) +} + +# ------------------------------------------------------------------------------ +# Helper functions + +check_workflow_fit <- function(x, call) { + if (!x$trained) { + cli::cli_abort("The workflow should be trained.", call = call) + } + invisible(NULL) +} + +check_recipe_fit <- function(x, call) { + is_trained <- vapply(x$steps, function(x) x$trained, logical(1)) + if (!all(is_trained)) { + cli::cli_abort("All recipe steps should be trained.", call = call) + } + invisible(NULL) +} + +blueprint_ptype <- function(x) { + names(x$pre$mold$blueprint$ptypes$predictors) +} + +is_predictor_role <- function(x) { + vapply(x$role, function(x) any(x == "predictor"), logical(1)) +} + +# nocov end diff --git a/tests/testthat/_snaps/input-names.md b/tests/testthat/_snaps/input-names.md index b087f326..dce747e1 100644 --- a/tests/testthat/_snaps/input-names.md +++ b/tests/testthat/_snaps/input-names.md @@ -3,30 +3,30 @@ Code workflows:::.get_input_predictors_workflow(workflow) Condition - Error in `check_workflow_fit()`: - ! The workflow should be trainined. + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. --- Code workflows:::.get_input_predictors_recipe(rec_with_id) Condition - Error in `check_recipe_fit()`: - ! All recipe steps should be trainined. + Error in `workflows:::.get_input_predictors_recipe()`: + ! All recipe steps should be trained. # get formula input column names Code workflows:::.get_input_predictors_workflow(workflow) Condition - Error in `check_workflow_fit()`: - ! The workflow should be trainined. + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. # get predictor input column names Code workflows:::.get_input_predictors_workflow(workflow) Condition - Error in `check_workflow_fit()`: - ! The workflow should be trainined. + Error in `workflows:::.get_input_predictors_workflow()`: + ! The workflow should be trained. From e2d83a31e964200751e852b05a265e371303ecc0 Mon Sep 17 00:00:00 2001 From: topepo Date: Mon, 29 Jan 2024 13:40:19 -0500 Subject: [PATCH 5/6] renamed file --- .../testthat/_snaps/{input-names.md => standalone-input-names.md} | 0 .../{test-input-names.R => test-standalone-input-names.R} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/testthat/_snaps/{input-names.md => standalone-input-names.md} (100%) rename tests/testthat/{test-input-names.R => test-standalone-input-names.R} (100%) diff --git a/tests/testthat/_snaps/input-names.md b/tests/testthat/_snaps/standalone-input-names.md similarity index 100% rename from tests/testthat/_snaps/input-names.md rename to tests/testthat/_snaps/standalone-input-names.md diff --git a/tests/testthat/test-input-names.R b/tests/testthat/test-standalone-input-names.R similarity index 100% rename from tests/testthat/test-input-names.R rename to tests/testthat/test-standalone-input-names.R From 1180fde934fa4e6bcce4eeb491ce9476d0481879 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Mon, 29 Jan 2024 15:22:11 -0500 Subject: [PATCH 6/6] fix/change dates --- R/standalone-input-names.R | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/R/standalone-input-names.R b/R/standalone-input-names.R index 437f12e5..71e33e75 100644 --- a/R/standalone-input-names.R +++ b/R/standalone-input-names.R @@ -1,7 +1,7 @@ # --- # repo: tidymodels/workflows # file: standalone-input-names.R -# last-updated: 2024-01-291 +# last-updated: 2024-01-21 # license: https://unlicense.org # requires: cli, rlang # --- @@ -12,8 +12,6 @@ # ## Changelog # 2024-01-21 # * First version -# 2024-01-29 -# * Changes after PR review # nocov start