Skip to content

Commit

Permalink
add partykit decision tree classification support
Browse files Browse the repository at this point in the history
  • Loading branch information
EmilHvitfeldt committed Dec 20, 2024
1 parent 6255be8 commit f377ac8
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 1 deletion.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Imports:
rlang
Suggests:
arrow,
bonsai,
DBI,
dbplyr,
dtplyr,
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(augment,orbital_class)
S3method(orbital,constparty)
S3method(orbital,default)
S3method(orbital,glm)
S3method(orbital,last_fit)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

* `orbital()` now works with `boost_tree(engine = "xgboost")` models for class prediction and probability predictions. (#71)

* `orbital()` now works with `decision_tree(engine = "partykit")` models for class prediction and probability predictions. (#77)

# orbital 0.2.0

Expand Down
30 changes: 30 additions & 0 deletions R/model-partykit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#' @export
orbital.constparty <- function(
x,
...,
mode = c("classification", "regression"),
type = NULL,
lvl = NULL
) {
mode <- rlang::arg_match(mode)
type <- default_type(type)

if (mode == "classification") {
res <- character()
if ("class" %in% type) {
eq <- tidypredict::tidypredict_fit(x)
eq <- deparse1(eq)
eq <- namespace_case_when(eq)
res <- c(res, orbital_tmp_class_name = eq)
}
if ("prob" %in% type) {
eqs <- tidypredict::.extract_partykit_classprob(x)
eqs <- namespace_case_when(eqs)
names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(lvl))
res <- c(res, eqs)
}
} else if (mode == "regression") {
res <- tidypredict::tidypredict_fit(x)
}
res
}
94 changes: 94 additions & 0 deletions tests/testthat/test-model-partykit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
test_that("decision_tree(partykit) works with type = class", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("bonsai")
library(bonsai)

mtcars$vs <- factor(mtcars$vs)

lr_spec <- parsnip::decision_tree("classification", "partykit")

lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(lr_fit, type = "class")

preds <- predict(orb_obj, mtcars)
exps <- predict(lr_fit, mtcars)

expect_named(preds, ".pred_class")
expect_type(preds$.pred_class, "character")

expect_identical(
preds$.pred_class,
as.character(exps$.pred_class)
)
})

test_that("decision_tree(partykit) works with type = prob", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("bonsai")
library(bonsai)

mtcars$vs <- factor(mtcars$vs)

lr_spec <- parsnip::decision_tree("classification", "partykit")

lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(lr_fit, type = "prob")

preds <- predict(orb_obj, mtcars)
exps <- predict(lr_fit, mtcars, type = "prob")

expect_named(preds, c(".pred_0", ".pred_1"))
expect_type(preds$.pred_0, "double")
expect_type(preds$.pred_1, "double")

exps <- as.data.frame(exps)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps
)
})

test_that("decision_tree(partykit) works with type = c(class, prob)", {
skip_if_not_installed("parsnip")
skip_if_not_installed("tidypredict")
skip_if_not_installed("bonsai")
library(bonsai)

mtcars$vs <- factor(mtcars$vs)

lr_spec <- parsnip::decision_tree("classification", "partykit")

lr_fit <- parsnip::fit(lr_spec, vs ~ disp + mpg + hp, mtcars)

orb_obj <- orbital(lr_fit, type = c("class", "prob"))

preds <- predict(orb_obj, mtcars)
exps <- dplyr::bind_cols(
predict(lr_fit, mtcars, type = c("class")),
predict(lr_fit, mtcars, type = c("prob"))
)

expect_named(preds, c(".pred_class", ".pred_0", ".pred_1"))
expect_type(preds$.pred_class, "character")
expect_type(preds$.pred_0, "double")
expect_type(preds$.pred_1, "double")

exps <- as.data.frame(exps)
exps$.pred_class <- as.character(exps$.pred_class)

rownames(preds) <- NULL
rownames(exps) <- NULL

expect_equal(
preds,
exps
)
})
2 changes: 1 addition & 1 deletion vignettes/supported-models.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tibble::tribble(
~parsnip, ~engine, ~numeric, ~class, ~prob,
"`boost_tree()`", "`\"xgboost\"`", "✅", "✅", "✅",
"`cubist_rules()`", "`\"Cubist\"`", "✅", "❌", "❌",
"`decision_tree()`", "`\"partykit\"`", "✅", "", "",
"`decision_tree()`", "`\"partykit\"`", "✅", "", "",
"`linear_reg()`", "`\"lm\"`", "✅", "❌", "❌",
"`linear_reg()`", "`\"glmnet\"`", "⚪", "❌", "❌",
"`logistic_reg()`", "`\"glm\"`", "❌", "✅", "✅",
Expand Down

0 comments on commit f377ac8

Please sign in to comment.