Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Missing value handling #687

Merged
merged 15 commits into from
Oct 28, 2024
Merged
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.3
Date: 2024-08-20
Version: 0.16.4
Date: 2024-10-28
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.16.4
* Add handling of missing values for classification and regression

# ranger 0.16.3
* Fix a bug for always.split.variables (for some settings)

Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, any_na) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, poisson_tau, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, any_na)
}

numSmaller <- function(values, reference) {
Expand Down
16 changes: 8 additions & 8 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
}

## Check for old ranger version
if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) != 2) {
if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) < 2 || length(forest$child.nodeIDs[[1]]) > 3) {
stop("Error: Invalid forest object. Is the forest grown in ranger version <0.3.9? Try to predict with the same version the forest was grown.")
}
if (!is.null(forest$dependent.varID)) {
Expand Down Expand Up @@ -185,12 +185,12 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
if (!is.matrix(x) & !inherits(x, "Matrix")) {
x <- data.matrix(x)
}

## Check missing values
if (any(is.na(x))) {
offending_columns <- colnames(x)[colSums(is.na(x)) > 0]
stop("Missing data in columns: ",
paste0(offending_columns, collapse = ", "), ".", call. = FALSE)
## Missing values
if (anyNA(x)) {
any.na <- TRUE
} else {
any.na <- FALSE
}

## Num threads
Expand Down Expand Up @@ -281,7 +281,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
node.stats, time.interest, use.time.interest)
node.stats, time.interest, use.time.interest, any.na)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
55 changes: 45 additions & 10 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@
##' If \code{regularization.usedepth=TRUE}, \eqn{f^d} is used, where \emph{f} is the regularization factor and \emph{d} the depth of the node.
##' If regularization is used, multithreading is deactivated because all trees need access to the list of variables that are already included in the model.
##'
##' Missing values can be internally handled by setting \code{na.action = "na.learn"} (default), by omitting observations with missing values with \code{na.action = "na.omit"} or by stopping if missing values are found with \code{na.action = "na.fail"}.
##' With \code{na.action = "na.learn"}, missing values are ignored for calculating an initial split criterion value (i.e., decrease of impurity). Then for the best split, all missings are tried in both child nodes and the choice is made based again on the split criterion value.
##' For prediction, this direction is saved as the "default" direction. If a missing occurs in prediction at a node where there is no default direction, it goes left.
##'
##' For a large number of variables and data frames as input data the formula interface can be slow or impossible to use.
##' Alternatively \code{dependent.variable.name} (and \code{status.variable.name} for survival) or \code{x} and \code{y} can be used.
##' Use \code{x} and \code{y} with a matrix for \code{x} to avoid conversions and save memory.
Expand Down Expand Up @@ -142,6 +146,7 @@
##' @param verbose Show computation status and estimated runtime.
##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node.
##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed.
##' @param na.action Handling of missing values. Set to "na.learn" to internally handle missing values (default, see below), to "na.omit" to omit observations with missing values and to "na.fail" to stop if missing values are found.
##' @param dependent.variable.name Name of dependent variable, needed if no formula given. For survival forests this is the time variable.
##' @param status.variable.name Name of status variable, only applicable to survival data and needed if no formula given. Use 1 for event and 0 for censoring.
##' @param classification Set to \code{TRUE} to grow a classification forest. Only needed if the data is a matrix or the response numeric.
Expand Down Expand Up @@ -250,7 +255,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
keep.inbag = FALSE, inbag = NULL, holdout = FALSE,
quantreg = FALSE, time.interest = NULL, oob.error = TRUE,
num.threads = NULL, save.memory = FALSE,
verbose = TRUE, node.stats = FALSE, seed = NULL,
verbose = TRUE, node.stats = FALSE, seed = NULL, na.action = "na.learn",
dependent.variable.name = NULL, status.variable.name = NULL,
classification = NULL, x = NULL, y = NULL, ...) {

Expand Down Expand Up @@ -325,13 +330,37 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}

## Check missing values
if (any(is.na(x))) {
offending_columns <- colnames(x)[colSums(is.na(x)) > 0]
stop("Missing data in columns: ",
paste0(offending_columns, collapse = ", "), ".", call. = FALSE)
}
if (any(is.na(y))) {
stop("Missing data in dependent variable.", call. = FALSE)
any.na <- FALSE
if (na.action == "na.fail") {
if (anyNA(x)) {
offending_columns <- colnames(x)[colSums(is.na(x)) > 0]
stop("Error: Missing data in columns: ",
paste0(offending_columns, collapse = ", "), ".", call. = FALSE)
}
if (anyNA(y)) {
stop("Error: Missing data in dependent variable.", call. = FALSE)
}
} else if (na.action == "na.omit") {
if (anyNA(x)) {
idx_keep <- stats::complete.cases(x)
x <- x[idx_keep, , drop = FALSE]
y <- y[idx_keep, drop = FALSE]
if (nrow(x) < 1) {
stop("Error: No observations left after removing missing values.")
}
}
} else if (na.action == "na.learn") {
if (anyNA(y)) {
stop("Error: Missing data in dependent variable.", call. = FALSE)
}
if (anyNA(x)) {
any.na <- TRUE
if (!is.null(splitrule) && !(splitrule %in% c("gini", "variance"))) {
stop("Error: Missing value handling currently only implemented for gini and variance splitrules.")
}
}
} else {
stop("Error: Invalid value for na.action. Use 'na.learn', 'na.omit' or 'na.fail'.")
}

## Check response levels
Expand Down Expand Up @@ -364,6 +393,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Unsupported type of dependent variable.")
}

## No missing value handling for survival yet
if (any.na & treetype == 5) {
stop("Error: Missing value handling not yet implemented for survival forests.")
}

## Number of levels
if (treetype %in% c(1, 9)) {
if (is.factor(y)) {
Expand Down Expand Up @@ -444,7 +478,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
} else {
## Order factor levels by mean response
means <- sapply(levels(xx), function(y) {
mean(num.y[xx == y])
mean(num.y[xx == y], na.rm = TRUE)
})
levels.ordered <- as.character(levels(xx)[order(means)])
}
Expand Down Expand Up @@ -992,7 +1026,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
holdout, prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
node.stats, time.interest, use.time.interest)
node.stats, time.interest, use.time.interest, any.na)


if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
2 changes: 1 addition & 1 deletion R/treeInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ treeInfo <- function(object, tree = 1) {
if (forest$treetype == "Survival" && (is.null(forest$chf) || is.null(forest$unique.death.times))) {
stop("Error: Invalid forest object.")
}
if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) != 2) {
if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) < 2 || length(forest$child.nodeIDs[[1]]) > 3) {
stop("Error: Invalid forest object. Is the forest grown in ranger version <0.3.9? Try with the same version the forest was grown.")
}
if (!is.null(forest$dependent.varID)) {
Expand Down
7 changes: 7 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 32 additions & 7 deletions src/Data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace ranger {

Data::Data() :
num_rows(0), num_rows_rounded(0), num_cols(0), snp_data(0), num_cols_no_snp(0), externalData(true), index_data(0), max_num_unique_values(
0), order_snps(false) {
0), order_snps(false), any_na(false) {
}

size_t Data::getVariableID(const std::string& variable_name) const {
Expand Down Expand Up @@ -224,8 +224,19 @@ void Data::getAllValues(std::vector<double>& all_values, std::vector<size_t>& sa
for (size_t pos = start; pos < end; ++pos) {
all_values.push_back(get_x(sampleIDs[pos], varID));
}
std::sort(all_values.begin(), all_values.end());
if (any_na) {
std::sort(all_values.begin(), all_values.end(), less_nan<double>);
} else {
std::sort(all_values.begin(), all_values.end());
}
all_values.erase(std::unique(all_values.begin(), all_values.end()), all_values.end());

// Keep only one NaN value
if (any_na) {
while (all_values.size() >= 2 && std::isnan(all_values[all_values.size() - 2])) {
all_values.pop_back();
}
}
} else {
// If GWA data just use 0, 1, 2
all_values = std::vector<double>( { 0, 1, 2 });
Expand Down Expand Up @@ -262,17 +273,31 @@ void Data::sort() {
for (size_t row = 0; row < num_rows; ++row) {
unique_values[row] = get_x(row, col);
}
std::sort(unique_values.begin(), unique_values.end());

if (any_na) {
std::sort(unique_values.begin(), unique_values.end(), less_nan<double>);
} else {
std::sort(unique_values.begin(), unique_values.end());
}
unique_values.erase(unique(unique_values.begin(), unique_values.end()), unique_values.end());

// Get index of unique value
for (size_t row = 0; row < num_rows; ++row) {
size_t idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col))
- unique_values.begin();
size_t idx;
if (any_na) {
idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col), less_nan<double>) - unique_values.begin();
} else {
idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col)) - unique_values.begin();
}
index_data[col * num_rows + row] = idx;
}

// Save unique values

// Save unique values (keep NaN)
if (any_na) {
while (unique_values.size() >= 2 && std::isnan(unique_values[unique_values.size() - 2])) {
unique_values.pop_back();
}
}
unique_data_values.push_back(unique_values);
if (unique_values.size() > max_num_unique_values) {
max_num_unique_values = unique_values.size();
Expand Down
7 changes: 7 additions & 0 deletions src/Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ class Data {
}
return varID;
}

const bool hasNA() const {
return any_na;
}

// #nocov start (cannot be tested anymore because GenABEL not on CRAN)
const std::vector<std::vector<size_t>>& getSnpOrder() const {
Expand Down Expand Up @@ -221,6 +225,9 @@ class Data {
// Order of 0/1/2 for ordered splitting
std::vector<std::vector<size_t>> snp_order;
bool order_snps;

// Any missing values?
bool any_na;
};

} // namespace ranger
Expand Down
3 changes: 2 additions & 1 deletion src/DataRcpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ namespace ranger {
class DataRcpp: public Data {
public:
DataRcpp() = default;
DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows, size_t num_cols) {
DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows, size_t num_cols, bool any_na) {
this->x = x;
this->y = y;
this->variable_names = variable_names;
this->num_rows = num_rows;
this->num_cols = num_cols;
this->num_cols_no_snp = num_cols;
this->any_na = any_na;
}

DataRcpp(const DataRcpp&) = delete;
Expand Down
3 changes: 2 additions & 1 deletion src/DataSparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
namespace ranger {

DataSparse::DataSparse(Eigen::SparseMatrix<double>& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
size_t num_cols) :
size_t num_cols, bool any_na) :
x { }{
this->x.swap(x);
this->y = y;
this->variable_names = variable_names;
this->num_rows = num_rows;
this->num_cols = num_cols;
this->num_cols_no_snp = num_cols;
this->any_na = any_na;
}

} // namespace ranger
2 changes: 1 addition & 1 deletion src/DataSparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class DataSparse: public Data {
DataSparse() = default;

DataSparse(Eigen::SparseMatrix<double>& x, Rcpp::NumericMatrix& y, std::vector<std::string> variable_names, size_t num_rows,
size_t num_cols);
size_t num_cols, bool any_na);

DataSparse(const DataSparse&) = delete;
DataSparse& operator=(const DataSparse&) = delete;
Expand Down
Loading
Loading