diff --git a/DESCRIPTION b/DESCRIPTION
index c2bb6bf..4449631 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -21,7 +21,6 @@ BugReports: https://github.com/mlverse/tabnet/issues
 Depends:
     R (>= 3.6)
 Imports:
-    cli,
     coro,
     data.tree,
     dials,
@@ -46,6 +45,7 @@ Imports:
     withr,
     zeallot
 Suggests:
+    cli,
     knitr,
     modeldata,
     patchwork,
diff --git a/R/dials.R b/R/dials.R
index cc0d8e6..76cac0c 100644
--- a/R/dials.R
+++ b/R/dials.R
@@ -3,6 +3,11 @@ check_dials <- function() {
     stop("Package \"dials\" needed for this function to work. Please install it.", call. = FALSE)
 }
 
+check_cli <- function() {
+  if (!requireNamespace("cli", quietly = TRUE))
+    stop("Package \"cli\" needed for this function to work. Please install it.", call. = FALSE)
+}
+
 #' Parameters for the tabnet model
@@ -127,57 +132,61 @@ num_steps <- function(range = c(3L, 10L), trans = NULL) {
   )
 }
 
-#' @noRd
+#' Non-tunable parameters for the tabnet model
+#'
+#' @param range unused
+#' @param trans unused
+#' @rdname tabnet_non_tunable
 #' @export
 cat_emb_dim <- function(range = NULL, trans = NULL) {
-  check_dials()
+  check_cli()
   cli::cli_abort("{.var cat_emb_dim} cannot be used as a {.fun tune} parameter yet.")
 }
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 checkpoint_epochs <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 drop_last <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 encoder_activation <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 lr_scheduler <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 mlp_activation <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 mlp_hidden_multiplier <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 num_independent_decoder <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 num_shared_decoder <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 optimizer <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 penalty <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 verbose <- cat_emb_dim
 
-#' @noRd
+#' @rdname tabnet_non_tunable
 #' @export
 virtual_batch_size <- cat_emb_dim
diff --git a/R/parsnip.R b/R/parsnip.R
index 4b4e48a..62df1b9 100644
--- a/R/parsnip.R
+++ b/R/parsnip.R
@@ -427,6 +427,11 @@ add_parsnip_tabnet <- function() {
 #' for this model are "unknown", "regression", or "classification".
 #' @inheritParams tabnet_config
 #' @inheritParams tabnet_fit
+#' @param rate_decay multiplies the initial learning rate by `rate_decay` every
+#' `rate_step_size` epochs. Unused if `lr_scheduler` is a `torch::lr_scheduler`
+#' or `NULL`.
+#' @param rate_step_size the learning rate scheduler step size. Unused if
+#' `lr_scheduler` is a `torch::lr_scheduler` or `NULL`.
 #'
 #' @inheritSection tabnet_fit Threading
 #' @seealso tabnet_fit
diff --git a/inst/WORDLIST b/inst/WORDLIST
index dc00692..3978496 100644
--- a/inst/WORDLIST
+++ b/inst/WORDLIST
@@ -15,7 +15,6 @@ Pretrain
 Sercan
 TabNet
 TabNet's
-adam
 ai
 al
 ames
@@ -37,6 +36,7 @@ ggplot
 interpretable
 mse
 nn
+num
 orginal
 overfit
 overfits
@@ -49,4 +49,5 @@ sparsemax
 subprocesses
 th
 tidymodels
+tunable
 zeallot
diff --git a/man/tabnet.Rd b/man/tabnet.Rd
index 0705d66..37106d6 100644
--- a/man/tabnet.Rd
+++ b/man/tabnet.Rd
@@ -113,6 +113,13 @@ decays the learning rate by \code{lr_decay} when no improvement after \code{step
 It can also be a \link[torch:lr_scheduler]{torch::lr_scheduler} function that only
 takes the optimizer as parameter. The \code{step} method is called once per epoch.}
 
+\item{rate_decay}{multiplies the initial learning rate by \code{rate_decay} every
+\code{rate_step_size} epochs. Unused if \code{lr_scheduler} is a \code{torch::lr_scheduler}
+or \code{NULL}.}
+
+\item{rate_step_size}{the learning rate scheduler step size. Unused if
+\code{lr_scheduler} is a \code{torch::lr_scheduler} or \code{NULL}.}
+
 \item{checkpoint_epochs}{checkpoint model weights and architecture every
 \code{checkpoint_epochs}. (default is 10). This may cause large memory usage.
 Use \code{0} to disable checkpoints.}
diff --git a/man/tabnet_non_tunable.Rd b/man/tabnet_non_tunable.Rd
new file mode 100644
index 0000000..0fa7487
--- /dev/null
+++ b/man/tabnet_non_tunable.Rd
@@ -0,0 +1,52 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dials.R
+\name{cat_emb_dim}
+\alias{cat_emb_dim}
+\alias{checkpoint_epochs}
+\alias{drop_last}
+\alias{encoder_activation}
+\alias{lr_scheduler}
+\alias{mlp_activation}
+\alias{mlp_hidden_multiplier}
+\alias{num_independent_decoder}
+\alias{num_shared_decoder}
+\alias{optimizer}
+\alias{penalty}
+\alias{verbose}
+\alias{virtual_batch_size}
+\title{Non-tunable parameters for the tabnet model}
+\usage{
+cat_emb_dim(range = NULL, trans = NULL)
+
+checkpoint_epochs(range = NULL, trans = NULL)
+
+drop_last(range = NULL, trans = NULL)
+
+encoder_activation(range = NULL, trans = NULL)
+
+lr_scheduler(range = NULL, trans = NULL)
+
+mlp_activation(range = NULL, trans = NULL)
+
+mlp_hidden_multiplier(range = NULL, trans = NULL)
+
+num_independent_decoder(range = NULL, trans = NULL)
+
+num_shared_decoder(range = NULL, trans = NULL)
+
+optimizer(range = NULL, trans = NULL)
+
+penalty(range = NULL, trans = NULL)
+
+verbose(range = NULL, trans = NULL)
+
+virtual_batch_size(range = NULL, trans = NULL)
+}
+\arguments{
+\item{range}{unused}
+
+\item{trans}{unused}
+}
+\description{
+Non-tunable parameters for the tabnet model
+}
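Not part of the patch: a minimal usage sketch of the documented behaviour. It assumes the parsnip `tabnet()` spec accepts the newly documented `rate_decay` and `rate_step_size` arguments alongside `lr_scheduler = "step"`, as the added `@param` entries describe, and that the non-tunable dials placeholders abort through `cli` (which is why `cli` can move from Imports to Suggests). Names and values are illustrative only.

library(parsnip)
library(tabnet)

# Hypothetical spec: with a "step" scheduler, the learning rate is multiplied
# by `rate_decay` every `rate_step_size` epochs, per the new documentation.
spec <- set_engine(
  tabnet(
    mode = "regression",
    epochs = 30,
    lr_scheduler = "step",
    rate_decay = 0.3,     # multiply the learning rate by 0.3 ...
    rate_step_size = 10   # ... every 10 epochs
  ),
  "torch"
)

# The exported dials placeholders are guards rather than tunable parameters:
# calling one aborts with a cli-formatted message, so cli is only required at
# that point.
try(cat_emb_dim())
#> Error : `cat_emb_dim` cannot be used as a `tune()` parameter yet.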