diff --git a/NEWS.md b/NEWS.md index d184da56a..db1a11142 100644 --- a/NEWS.md +++ b/NEWS.md @@ -33,6 +33,8 @@ * Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166). +* Fixed bug related to using local (non-package) models (#1229) + * `tunable()` now references a dials object for the `mixture` parameter (#1236) ## Breaking Change diff --git a/R/misc.R b/R/misc.R index 4a3537f89..f02bf9a54 100644 --- a/R/misc.R +++ b/R/misc.R @@ -241,13 +241,15 @@ prompt_missing_implementation <- function(spec, #' @keywords internal #' @export show_call <- function(object) { - object$method$fit$args <- - map(object$method$fit$args, convert_arg) + object$method$fit$args <- map(object$method$fit$args, convert_arg) - call2(object$method$fit$func["fun"], - !!!object$method$fit$args, - .ns = object$method$fit$func["pkg"] - ) + fn_info <- as.list(object$method$fit$func) + if (!any(names(fn_info) == "pkg")) { + res <- call2(fn_info$fun, !!!object$method$fit$args) + } else { + res <- call2(fn_info$fun, !!!object$method$fit$args, .ns = fn_info$pkg) + } + res } convert_arg <- function(x) { @@ -301,8 +303,8 @@ check_args.default <- function(object, call = rlang::caller_env()) { # ------------------------------------------------------------------------------ -# copied form recipes - +# copied from recipes +# nocov start names0 <- function(num, prefix = "x", call = rlang::caller_env()) { if (num < 1) { cli::cli_abort("{.arg num} should be > 0.", call = call) @@ -311,7 +313,7 @@ names0 <- function(num, prefix = "x", call = rlang::caller_env()) { ind <- gsub(" ", "0", ind) paste0(prefix, ind) } - +# nocov end # ------------------------------------------------------------------------------ diff --git a/parsnip.Rproj b/parsnip.Rproj index 060c78308..92c87e240 100644 --- a/parsnip.Rproj +++ b/parsnip.Rproj @@ -1,6 +1,7 @@ Version: 1.0 ProjectId: 7f6c9ff5-6b9a-4235-8666-12db5ef65d49 + RestoreWorkspace: No SaveWorkspace: No AlwaysSaveHistory: Default diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index b221b1dde..6e43ec0d7 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -243,3 +243,15 @@ Error in `.get_prediction_column_names()`: ! Prediction information could not be found for this `linear_reg()` with engine "lm" and mode "Depeche". Does a parsnip extension package need to be loaded? +# register local models + + Code + my_model() %>% translate("my_engine") + Output + my model Model Specification (regression) + + Computational engine: my_engine + + Model fit template: + my_model_fun(formula = missing_arg(), data = missing_arg()) + diff --git a/tests/testthat/test-decision_tree.R b/tests/testthat/test-decision_tree.R index 7d32fdf0a..faecf59f9 100644 --- a/tests/testthat/test-decision_tree.R +++ b/tests/testthat/test-decision_tree.R @@ -25,6 +25,16 @@ test_that('bad input', { ) }) +test_that('rpart_train is stop-deprecated when it ought to be (#1044)', { + skip_on_cran() + + # once this test fails, transition `rpart_train()` to `deprecate_stop()` + # and transition this test to fail if `rpart_train()` still exists after a year. + if (Sys.Date() > "2025-02-01") { + expect_snapshot(error = TRUE, rpart_train(mpg ~ ., mtcars)) + } +}) + # ------------------------------------------------------------------------------ test_that('argument checks for data dimensions', { diff --git a/tests/testthat/test-misc.R b/tests/testthat/test-misc.R index 901c92748..b9566fb0f 100644 --- a/tests/testthat/test-misc.R +++ b/tests/testthat/test-misc.R @@ -299,3 +299,44 @@ test_that('obtaining prediction columns', { ) }) + + +# ------------------------------------------------------------------------------ + +# https://github.com/tidymodels/parsnip/issues/1229 +test_that('register local models', { + set_new_model("my_model") + set_model_mode(model = "my_model", mode = "regression") + set_model_engine( + "my_model", + mode = "regression", + eng = "my_engine" + ) + + my_model <- + function(mode = "regression") { + new_model_spec( + "my_model", + args = list(), + eng_args = NULL, + mode = mode, + method = NULL, + engine = NULL + ) + } + + set_fit( + model = "my_model", + eng = "my_engine", + mode = "regression", + value = list( + interface = "matrix", + protect = c("formula", "data"), + func = c(fun = "my_model_fun"), + defaults = list() + ) + ) + + expect_snapshot(my_model() %>% translate("my_engine")) +}) +