Skip to content

Commit

Permalink
add num_workers option to dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
cregouby committed Jan 30, 2022
1 parent b473fab commit a16b4cc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
18 changes: 13 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ resolve_data <- function(x, y) {
#' @param early_stopping_tolerance Minimum relative improvement to reset the patience counter.
#' 0.01 for 1% tolerance (default 0)
#' @param early_stopping_patience Number of epochs without improving until stopping training. (default=5)
#' @param num_workers (int, optional): how many subprocesses to use for data
#' loading. 0 means that the data will be loaded in the main process.
#' (default: `0`)
#' @return A named list with all hyperparameters of the TabNet implementation.
#'
#' @export
Expand Down Expand Up @@ -139,7 +142,8 @@ tabnet_config <- function(batch_size = 256,
importance_sample_size = NULL,
early_stopping_monitor = "auto",
early_stopping_tolerance = 0,
early_stopping_patience = 0L) {
early_stopping_patience = 0L,
num_workers=0L) {
if (is.null(decision_width) && is.null(attention_width)) {
decision_width <- 8 # default is 8
}
Expand Down Expand Up @@ -181,7 +185,8 @@ tabnet_config <- function(batch_size = 256,
early_stopping_monitor = resolve_early_stop_monitor(early_stopping_monitor, valid_split),
early_stopping_tolerance = early_stopping_tolerance,
early_stopping_patience = early_stopping_patience,
early_stopping = !(early_stopping_tolerance==0 || early_stopping_patience==0)
early_stopping = !(early_stopping_tolerance==0 || early_stopping_patience==0),
num_workers = num_workers
)
}

Expand Down Expand Up @@ -379,7 +384,8 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
torch::tensor_dataset(x = train_mat$x, na_mask = train_mat$x_na_mask, y = train_mat$y),
batch_size = config$batch_size,
drop_last = config$drop_last,
shuffle = TRUE
shuffle = TRUE ,
num_workers = config$num_workers
)

# validation data
Expand All @@ -389,7 +395,8 @@ tabnet_train_supervised <- function(obj, x, y, config = tabnet_config(), epoch_s
torch::tensor_dataset(x = valid_mat$x, na_mask = valid_mat$x_na_mask, y = valid_mat$y),
batch_size = config$batch_size,
drop_last = FALSE,
shuffle = FALSE
shuffle = FALSE ,
num_workers = config$num_workers
)
}

Expand Down Expand Up @@ -548,7 +555,8 @@ predict_impl <- function(obj, x, batch_size = 1e5) {
torch::tensor_dataset(x = predict_mat$x, na_mask = predict_mat$x_na_mask),
batch_size = batch_size,
drop_last = FALSE,
shuffle = FALSE
shuffle = FALSE ,
num_workers = config$num_workers
)
coro::loop(for (batch in predict_dl) {
yhat <- c(yhat, network(batch$x, batch$na_mask)[[1]])
Expand Down
7 changes: 6 additions & 1 deletion man/tabnet_config.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a16b4cc

Please sign in to comment.