Skip to content

Commit

Permalink
use namespace_case_when() more generally
Browse files Browse the repository at this point in the history
to close #53
  • Loading branch information
EmilHvitfeldt committed Dec 20, 2024
1 parent 8c74a57 commit 0bbaf4c
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
6 changes: 2 additions & 4 deletions R/model-partykit.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@ orbital.constparty <- function(
if (mode == "classification") {
res <- character()
if ("class" %in% type) {
eq <- tidypredict::tidypredict_fit(x)
eq <- deparse1(eq)
eq <- namespace_case_when(eq)
eq <- tidypredict::tidypredict_fit(x)
eq <- deparse1(eq)
res <- c(res, orbital_tmp_class_name = eq)
}
if ("prob" %in% type) {
eqs <- tidypredict::.extract_partykit_classprob(x)
eqs <- namespace_case_when(eqs)
names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(lvl))
res <- c(res, eqs)
}
Expand Down
8 changes: 0 additions & 8 deletions R/model-xgboost.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ xgboost_multisoft <- function(x, type, lvl) {
trees_split <- split(trees, rep(seq_along(lvl), x$niter))
trees_split <- lapply(trees_split, collapse_stumps)
trees_split <- vapply(trees_split, paste, character(1), collapse = " + ")
trees_split <- namespace_case_when(trees_split)

res <- stats::setNames(trees_split, lvl)

Expand Down Expand Up @@ -75,7 +74,6 @@ xgboost_logistic <- function(x, type, lvl) {
eq <- tidypredict::tidypredict_fit(x)

eq <- deparse1(eq)
eq <- namespace_case_when(eq)

res <- NULL
if ("class" %in% type) {
Expand All @@ -98,12 +96,6 @@ xgboost_logistic <- function(x, type, lvl) {
res
}

namespace_case_when <- function(x) {
x <- gsub("dplyr::case_when", "case_when", x)
x <- gsub("case_when", "dplyr::case_when", x)
x
}

softmax <- function(lvl) {
res <- character(0)

Expand Down
1 change: 1 addition & 0 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ orbital.model_fit <- function(x, ..., prefix = ".pred", type = NULL) {
res <- deparse1(res)
}

res <- namespace_case_when(res)
res <- set_pred_names(res, x, mode, type, prefix)

new_orbital_class(res)
Expand Down
7 changes: 7 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace_case_when <- function(x) {
names <- names(x)
x <- gsub("dplyr::case_when", "case_when", x)
x <- gsub("case_when", "dplyr::case_when", x)
names(x) <- names
x
}
2 changes: 2 additions & 0 deletions R/workflows.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,7 @@ orbital.workflow <- function(x, ..., prefix = ".pred", type = NULL) {
attr(out, "pred_names") <- pred_names
}

out <- namespace_case_when(out)

new_orbital_class(out)
}

0 comments on commit 0bbaf4c

Please sign in to comment.