diff --git a/DESCRIPTION b/DESCRIPTION index 79dc896d..ed921305 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,60 +1,68 @@ Package: tabnet Title: Fit 'TabNet' Models for Classification and Regression -Version: 0.6.0 +Version: 0.5.0 Authors@R: c( - person(given = "Daniel", family = "Falbel", role = c("aut"), email = "daniel@rstudio.com"), - person(family = "RStudio", role = c("cph")), - person(given = "Christophe", family = "Regouby", role = c("cre", "ctb"), email = "christophe.regouby@free.fr"), - person(given = "Egill", family = "Fridgeirsson", role = c("ctb")), - person(given = "Philipp", family = "Haarmeyer", role = c("ctb")), - person(given = "Sven", family = "Verweij", role = c("ctb"), comment = c(ORCID = "0000-0002-5573-3952")) - ) -Description: Implements the 'TabNet' model by Sercan O. Arik et al. (2019) - with 'Coherent Hierarchical Multi-label Classification Networks' by Giunchiglia et al. - and provides a consistent interface for fitting and creating - predictions. It's also fully compatible with the 'tidymodels' ecosystem. + person("Daniel", "Falbel", , "daniel@rstudio.com", role = "aut"), + person(, "RStudio", role = "cph"), + person("Christophe", "Regouby", , "christophe.regouby@free.fr", role = c("cre", "ctb")), + person("Egill", "Fridgeirsson", role = "ctb"), + person("Philipp", "Haarmeyer", role = "ctb"), + person("Sven", "Verweij", role = "ctb", + comment = c(ORCID = "0000-0002-5573-3952")) + ) +Description: Implements the 'TabNet' model by Sercan O. Arik et al. (2019) + with 'Coherent Hierarchical Multi-label + Classification Networks' by Giunchiglia et al. and + provides a consistent interface for fitting and creating predictions. + It's also fully compatible with the 'tidymodels' ecosystem. License: MIT + file LICENSE -Encoding: UTF-8 -Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 URL: https://mlverse.github.io/tabnet/, https://github.com/mlverse/tabnet BugReports: https://github.com/mlverse/tabnet/issues Depends: R (>= 3.6) Imports: - torch (>= 0.4.0), + coro, + data.tree, + dials, + dplyr, + ggplot2, hardhat (>= 1.3.0), magrittr, + Matrix, + methods, + parsnip, progress, + purrr, rlang, - methods, - dplyr, + stats, + stringr, tibble, tidyr, - coro, + torch (>= 0.4.0), + tune, + utils, vctrs, + vip, + withr, zeallot Suggests: - testthat (>= 3.0.0), - data.tree, - Matrix, + knitr, modeldata, + patchwork, recipes, - rsample, - parsnip, - dials, - withr, - knitr, rmarkdown, - vip, + rsample, + testthat (>= 3.0.0), + tidymodels, tidyverse, - ggplot2, - purrr, - stringr, - tune, + visdat, workflows, yardstick -VignetteBuilder: knitr +VignetteBuilder: + knitr Config/testthat/edition: 3 Config/testthat/parallel: false Config/testthat/start-first: interface, explain, params +Encoding: UTF-8 +Roxygen: list(markdown = TRUE) +RoxygenNote: 7.3.1 diff --git a/README.md b/README.md index 39350ffc..1536bb0a 100644 --- a/README.md +++ b/README.md @@ -98,13 +98,16 @@ cbind(test, predict(fit, test)) %>% #> 1 accuracy binary 0.837 #> 2 precision binary 0.837 #> 3 recall binary 1 +``` + +``` r cbind(test, predict(fit, test, type = "prob")) %>% roc_auc(Attrition, .pred_No) #> # A tibble: 1 × 3 #> .metric .estimator .estimate #> -#> 1 roc_auc binary 0.548 +#> 1 roc_auc binary 0.546 ``` ## Explain model on test-set with attention map diff --git a/man/figures/README-model-explain-1.png b/man/figures/README-model-explain-1.png index 36860e83..4f2715bb 100644 Binary files a/man/figures/README-model-explain-1.png and b/man/figures/README-model-explain-1.png differ diff --git a/man/figures/README-model-fit-1.png b/man/figures/README-model-fit-1.png index 022fe385..8001b1f6 100644 Binary files a/man/figures/README-model-fit-1.png and b/man/figures/README-model-fit-1.png differ diff --git a/man/figures/README-step-explain-1.png b/man/figures/README-step-explain-1.png index 97baea3c..e5c49b0d 100644 Binary files a/man/figures/README-step-explain-1.png and b/man/figures/README-step-explain-1.png differ diff --git a/man/figures/README-step-pretrain-1.png b/man/figures/README-step-pretrain-1.png index 59a99fbf..b8332f62 100644 Binary files a/man/figures/README-step-pretrain-1.png and b/man/figures/README-step-pretrain-1.png differ