From f377ac8ded4a3203721930e72e9d8dd6b084485d Mon Sep 17 00:00:00 2001
From: Emil Hvitfeldt
Date: Thu, 19 Dec 2024 17:19:03 -0800
Subject: [PATCH] add partykit decision tree classification support

---
 DESCRIPTION                          |  1 +
 NAMESPACE                            |  1 +
 NEWS.md                              |  1 +
 R/model-partykit.R                   | 51 +++++++++++++++
 tests/testthat/test-model-partykit.R | 94 ++++++++++++++++++++++++++++
 vignettes/supported-models.Rmd       |  2 +-
 6 files changed, 149 insertions(+), 1 deletion(-)
 create mode 100644 R/model-partykit.R
 create mode 100644 tests/testthat/test-model-partykit.R

diff --git a/DESCRIPTION b/DESCRIPTION
index c92350b..85fb5bb 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -19,6 +19,7 @@ Imports:
     rlang
 Suggests:
     arrow,
+    bonsai,
     DBI,
     dbplyr,
     dtplyr,
diff --git a/NAMESPACE b/NAMESPACE
index 7b82ceb..8241a52 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,6 +1,7 @@
 # Generated by roxygen2: do not edit by hand
 
 S3method(augment,orbital_class)
+S3method(orbital,constparty)
 S3method(orbital,default)
 S3method(orbital,glm)
 S3method(orbital,last_fit)
diff --git a/NEWS.md b/NEWS.md
index 715e31d..98fd1da 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -10,6 +10,7 @@
 
 * `orbital()` now works with `boost_tree(engine = "xgboost")` models for class prediction and probability predictions. (#71)
 
+* `orbital()` now works with `decision_tree(engine = "partykit")` models for class prediction and probability predictions. (#77)
 
 # orbital 0.2.0
 
diff --git a/R/model-partykit.R b/R/model-partykit.R
new file mode 100644
index 0000000..e595056
--- /dev/null
+++ b/R/model-partykit.R
@@ -0,0 +1,51 @@
+# Translate a fitted partykit tree (class `constparty`) into the expressions
+# that `orbital()` stores for prediction.
+#
+# Arguments:
+#   x     Fitted `constparty` model.
+#   ...   Not used by this method.
+#   mode  Either "classification" (the default) or "regression".
+#   type  Requested prediction types, passed through `default_type()`. Only
+#         consulted in classification mode, where "class" and "prob" are
+#         recognised.
+#   lvl   Not read by this method; retained so existing callers that pass it
+#         keep working.
+#
+# Returns:
+#   Classification: a vector holding `orbital_tmp_class_name` when "class" is
+#   requested, followed by `orbital_tmp_prob_name1`, `orbital_tmp_prob_name2`,
+#   ... (one per extracted probability equation) when "prob" is requested.
+#   Regression: the result of `tidypredict::tidypredict_fit()`, unchanged.
+#' @export
+orbital.constparty <- function(
+  x,
+  ...,
+  mode = c("classification", "regression"),
+  type = NULL,
+  lvl = NULL
+) {
+  mode <- rlang::arg_match(mode)
+  type <- default_type(type)
+
+  if (mode == "classification") {
+    res <- character()
+    if ("class" %in% type) {
+      eq <- tidypredict::tidypredict_fit(x)
+      eq <- deparse1(eq)
+      eq <- namespace_case_when(eq)
+      res <- c(res, orbital_tmp_class_name = eq)
+    }
+    if ("prob" %in% type) {
+      eqs <- tidypredict::.extract_partykit_classprob(x)
+      eqs <- namespace_case_when(eqs)
+      # Number the placeholders by the equations actually extracted. Indexing
+      # by `lvl` breaks when it is left at its NULL default: `paste0()` then
+      # returns one unnumbered name and the remaining names become NA.
+      names(eqs) <- paste0("orbital_tmp_prob_name", seq_along(eqs))
+      res <- c(res, eqs)
+    }
+  } else if (mode == "regression") {
+    res <- tidypredict::tidypredict_fit(x)
+  }
+  res
+}
diff --git a/tests/testthat/test-model-partykit.R b/tests/testthat/test-model-partykit.R
new file mode 100644
index 0000000..8c43c05
--- /dev/null
+++ b/tests/testthat/test-model-partykit.R
@@ -0,0 +1,94 @@
+test_that("decision_tree(partykit) works with type = class", {
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("tidypredict")
+  skip_if_not_installed("bonsai")
+  library(bonsai)
+
+  mtcars$vs <- factor(mtcars$vs)
+
+  tree_spec <- parsnip::decision_tree("classification", "partykit")
+
+  tree_fit <- parsnip::fit(tree_spec, vs ~ disp + mpg + hp, mtcars)
+
+  orb_obj <- orbital(tree_fit, type = "class")
+
+  preds <- predict(orb_obj, mtcars)
+  exps <- predict(tree_fit, mtcars)
+
+  expect_named(preds, ".pred_class")
+  expect_type(preds$.pred_class, "character")
+
+  expect_identical(
+    preds$.pred_class,
+    as.character(exps$.pred_class)
+  )
+})
+
+test_that("decision_tree(partykit) works with type = prob", {
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("tidypredict")
+  skip_if_not_installed("bonsai")
+  library(bonsai)
+
+  mtcars$vs <- factor(mtcars$vs)
+
+  tree_spec <- parsnip::decision_tree("classification", "partykit")
+
+  tree_fit <- parsnip::fit(tree_spec, vs ~ disp + mpg + hp, mtcars)
+
+  orb_obj <- orbital(tree_fit, type = "prob")
+
+  preds <- predict(orb_obj, mtcars)
+  exps <- predict(tree_fit, mtcars, type = "prob")
+
+  expect_named(preds, c(".pred_0", ".pred_1"))
+  expect_type(preds$.pred_0, "double")
+  expect_type(preds$.pred_1, "double")
+
+  exps <- as.data.frame(exps)
+
+  rownames(preds) <- NULL
+  rownames(exps) <- NULL
+
+  expect_equal(
+    preds,
+    exps
+  )
+})
+
+test_that("decision_tree(partykit) works with type = c(class, prob)", {
+  skip_if_not_installed("parsnip")
+  skip_if_not_installed("tidypredict")
+  skip_if_not_installed("bonsai")
+  library(bonsai)
+
+  mtcars$vs <- factor(mtcars$vs)
+
+  tree_spec <- parsnip::decision_tree("classification", "partykit")
+
+  tree_fit <- parsnip::fit(tree_spec, vs ~ disp + mpg + hp, mtcars)
+
+  orb_obj <- orbital(tree_fit, type = c("class", "prob"))
+
+  preds <- predict(orb_obj, mtcars)
+  exps <- dplyr::bind_cols(
+    predict(tree_fit, mtcars, type = "class"),
+    predict(tree_fit, mtcars, type = "prob")
+  )
+
+  expect_named(preds, c(".pred_class", ".pred_0", ".pred_1"))
+  expect_type(preds$.pred_class, "character")
+  expect_type(preds$.pred_0, "double")
+  expect_type(preds$.pred_1, "double")
+
+  exps <- as.data.frame(exps)
+  exps$.pred_class <- as.character(exps$.pred_class)
+
+  rownames(preds) <- NULL
+  rownames(exps) <- NULL
+
+  expect_equal(
+    preds,
+    exps
+  )
+})
diff --git a/vignettes/supported-models.Rmd b/vignettes/supported-models.Rmd
index a60f1d1..d6fd089 100644
--- a/vignettes/supported-models.Rmd
+++ b/vignettes/supported-models.Rmd
@@ -32,7 +32,7 @@ tibble::tribble(
   ~parsnip, ~engine, ~numeric, ~class, ~prob,
   "`boost_tree()`", "`\"xgboost\"`", "✅", "✅", "✅",
   "`cubist_rules()`", "`\"Cubist\"`", "✅", "❌", "❌",
-  "`decision_tree()`", "`\"partykit\"`", "✅", "⚪", "⚪",
+  "`decision_tree()`", "`\"partykit\"`", "✅", "✅", "✅",
   "`linear_reg()`", "`\"lm\"`", "✅", "❌", "❌",
   "`linear_reg()`", "`\"glmnet\"`", "⚪", "❌", "❌",
   "`logistic_reg()`", "`\"glm\"`", "❌", "✅", "✅",