diff --git a/R/model-xgboost.R b/R/model-xgboost.R index b8c51d6..60f74a5 100644 --- a/R/model-xgboost.R +++ b/R/model-xgboost.R @@ -36,6 +36,7 @@ xgboost_multisoft <- function(x, type, lvl) { trees <- tidypredict::.extract_xgb_trees(x) trees_split <- split(trees, rep(seq_along(lvl), x$niter)) + trees_split <- lapply(trees_split, collapse_stumps) trees_split <- vapply(trees_split, paste, character(1), collapse = " + ") trees_split <- namespace_case_when(trees_split) @@ -57,6 +58,19 @@ xgboost_multisoft <- function(x, type, lvl) { res } +collapse_stumps <- function(x) { + stump_ind <- lengths(x) == 2 + + stumps <- x[stump_ind] + trees <- x[!stump_ind] + + stump_values <- lapply(stumps, function(x) eval(x[[2]][[3]])) + stump_values <- unlist(stump_values) + stump_values <- sum(stump_values) + + c(stump_values, trees) +} + xgboost_logistic <- function(x, type, lvl) { eq <- tidypredict::tidypredict_fit(x)