From 81bf7d80218443523ce78fe00d2c2bbf4a528bdd Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 24 Aug 2023 10:28:47 +0200 Subject: [PATCH 01/11] quick and dirty missing value handling for classification --- R/predict.R | 9 +-- R/ranger.R | 33 +++++++--- R/treeInfo.R | 2 +- man/ranger.Rd | 3 + src/Data.cpp | 21 ++++-- src/Tree.cpp | 51 +++++++++++---- src/Tree.h | 5 +- src/TreeClassification.cpp | 117 +++++++++++++++++++++++++++------- src/utility.h | 16 +++++ tests/testthat/test_predict.R | 4 +- tests/testthat/test_ranger.R | 4 +- 11 files changed, 204 insertions(+), 61 deletions(-) diff --git a/R/predict.R b/R/predict.R index ef1397e73..50e659a69 100644 --- a/R/predict.R +++ b/R/predict.R @@ -103,7 +103,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, } ## Check for old ranger version - if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) != 2) { + if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) < 2 || length(forest$child.nodeIDs[[1]]) > 3) { stop("Error: Invalid forest object. Is the forest grown in ranger version <0.3.9? Try to predict with the same version the forest was grown.") } if (!is.null(forest$dependent.varID)) { @@ -183,13 +183,6 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, x <- data.matrix(x) } - ## Check missing values - if (any(is.na(x))) { - offending_columns <- colnames(x)[colSums(is.na(x)) > 0] - stop("Missing data in columns: ", - paste0(offending_columns, collapse = ", "), ".", call. = FALSE) - } - ## Num threads ## Default 0 -> detect from system in C++. if (is.null(num.threads)) { diff --git a/R/ranger.R b/R/ranger.R index 54a18342f..395712df6 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -120,6 +120,7 @@ ##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems. 
##' @param verbose Show computation status and estimated runtime. ##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. +##' @param na.action Handling of missing values. Set to "na.learn" to internally handle missing values (default, see below), to "na.omit" to omit observations with missing values and to "na.fail" to stop if missing values are found. ##' @param dependent.variable.name Name of dependent variable, needed if no formula given. For survival forests this is the time variable. ##' @param status.variable.name Name of status variable, only applicable to survival data and needed if no formula given. Use 1 for event and 0 for censoring. ##' @param classification Set to \code{TRUE} to grow a classification forest. Only needed if the data is a matrix or the response numeric. @@ -224,7 +225,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, keep.inbag = FALSE, inbag = NULL, holdout = FALSE, quantreg = FALSE, oob.error = TRUE, num.threads = NULL, save.memory = FALSE, - verbose = TRUE, seed = NULL, + verbose = TRUE, seed = NULL, na.action = "na.learn", dependent.variable.name = NULL, status.variable.name = NULL, classification = NULL, x = NULL, y = NULL, ...) { @@ -292,13 +293,25 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } ## Check missing values - if (any(is.na(x))) { - offending_columns <- colnames(x)[colSums(is.na(x)) > 0] - stop("Missing data in columns: ", - paste0(offending_columns, collapse = ", "), ".", call. = FALSE) - } - if (any(is.na(y))) { - stop("Missing data in dependent variable.", call. = FALSE) + if (na.action == "na.fail") { + if (any(is.na(x))) { + offending_columns <- colnames(x)[colSums(is.na(x)) > 0] + stop("Missing data in columns: ", + paste0(offending_columns, collapse = ", "), ".", call. = FALSE) + } + if (any(is.na(y))) { + stop("Missing data in dependent variable.", call. 
= FALSE) + } + } else if (na.action == "na.omit") { + # TODO: Implement na.omit + stop("na.omit not implemented yet.") + } else if (na.action == "na.learn") { + # TODO: fix y + if (any(is.na(y))) { + stop("Missing data in dependent variable.", call. = FALSE) + } + } else { + stop("Error: Invalid value for na.action. Use 'na.learn', 'na.omit' or 'na.fail'.") } ## Check response levels @@ -391,6 +404,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Don't order if only one level levels.ordered <- levels(xx) } else if (inherits(y, "Surv")) { + # TODO: Fix missings here ## Use median survival if available or largest quantile available in all strata if median not available levels.ordered <- largest.quantile(y ~ xx) @@ -398,11 +412,12 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, levels.missing <- setdiff(levels(xx), levels.ordered) levels.ordered <- c(levels.missing, levels.ordered) } else if (is.factor(y) & nlevels(y) > 2) { + # TODO: Fix missings here levels.ordered <- pca.order(y = y, x = xx) } else { ## Order factor levels by mean response means <- sapply(levels(xx), function(y) { - mean(num.y[xx == y]) + mean(num.y[xx == y], na.rm = TRUE) }) levels.ordered <- as.character(levels(xx)[order(means)]) } diff --git a/R/treeInfo.R b/R/treeInfo.R index f0a54d362..250fe66c0 100644 --- a/R/treeInfo.R +++ b/R/treeInfo.R @@ -76,7 +76,7 @@ treeInfo <- function(object, tree = 1) { if (forest$treetype == "Survival" && (is.null(forest$chf) || is.null(forest$unique.death.times))) { stop("Error: Invalid forest object.") } - if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) != 2) { + if (length(forest$child.nodeIDs) != forest$num.trees || length(forest$child.nodeIDs[[1]]) < 2 || length(forest$child.nodeIDs[[1]]) > 3) { stop("Error: Invalid forest object. Is the forest grown in ranger version <0.3.9? 
Try with the same version the forest was grown.") } if (!is.null(forest$dependent.varID)) { diff --git a/man/ranger.Rd b/man/ranger.Rd index 63d6d395e..8bcd5f37d 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -39,6 +39,7 @@ ranger( save.memory = FALSE, verbose = TRUE, seed = NULL, + na.action = "na.learn", dependent.variable.name = NULL, status.variable.name = NULL, classification = NULL, @@ -116,6 +117,8 @@ ranger( \item{seed}{Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed.} +\item{na.action}{Handling of missing values. Set to "na.learn" to internally handle missing values (default, see below), to "na.omit" to omit observations with missing values and to "na.fail" to stop if missing values are found.} + \item{dependent.variable.name}{Name of dependent variable, needed if no formula given. For survival forests this is the time variable.} \item{status.variable.name}{Name of status variable, only applicable to survival data and needed if no formula given. 
Use 1 for event and 0 for censoring.} diff --git a/src/Data.cpp b/src/Data.cpp index 0363e7fd9..52a5d091c 100644 --- a/src/Data.cpp +++ b/src/Data.cpp @@ -214,6 +214,7 @@ bool Data::loadFromFileOther(std::ifstream& input_file, std::string header_line, } // #nocov end +// TODO: Check if slower void Data::getAllValues(std::vector& all_values, std::vector& sampleIDs, size_t varID, size_t start, size_t end) const { @@ -224,8 +225,13 @@ void Data::getAllValues(std::vector& all_values, std::vector& sa for (size_t pos = start; pos < end; ++pos) { all_values.push_back(get_x(sampleIDs[pos], varID)); } - std::sort(all_values.begin(), all_values.end()); + std::sort(all_values.begin(), all_values.end(), less_nan); all_values.erase(std::unique(all_values.begin(), all_values.end()), all_values.end()); + + // Keep only one NaN value + while (all_values.size() >= 2 && std::isnan(all_values[all_values.size() - 2])) { + all_values.pop_back(); + } } else { // If GWA data just use 0, 1, 2 all_values = std::vector( { 0, 1, 2 }); @@ -262,17 +268,22 @@ void Data::sort() { for (size_t row = 0; row < num_rows; ++row) { unique_values[row] = get_x(row, col); } - std::sort(unique_values.begin(), unique_values.end()); + + // TODO: Check if this makes it slower if no NA. If yes, check in the beginning if there is any NA and overload a function based on that? Or just use Inf from the beginning? 
+ std::sort(unique_values.begin(), unique_values.end(), less_nan); unique_values.erase(unique(unique_values.begin(), unique_values.end()), unique_values.end()); // Get index of unique value for (size_t row = 0; row < num_rows; ++row) { - size_t idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col)) + size_t idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col), less_nan) - unique_values.begin(); index_data[col * num_rows + row] = idx; } - - // Save unique values + + // Save unique values (keep NaN) + while (unique_values.size() >= 2 && std::isnan(unique_values[unique_values.size() - 2])) { + unique_values.pop_back(); + } unique_data_values.push_back(unique_values); if (unique_values.size() > max_num_unique_values) { max_num_unique_values = unique_values.size(); diff --git a/src/Tree.cpp b/src/Tree.cpp index c6ef6303f..4adef3d83 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -51,6 +51,7 @@ void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, std: // Create root node, assign bootstrap sample and oob samples child_nodeIDs.push_back(std::vector()); child_nodeIDs.push_back(std::vector()); + child_nodeIDs.push_back(std::vector()); createEmptyNode(); // Initialize random number generator and set seed @@ -176,13 +177,24 @@ void Tree::predict(const Data* prediction_data, bool oob_prediction) { double value = prediction_data->get_x(sample_idx, split_varID); if (prediction_data->isOrderedVariable(split_varID)) { - if (value <= split_values[nodeID]) { - // Move to left child - nodeID = child_nodeIDs[0][nodeID]; + if (std::isnan(value)) { + if (child_nodeIDs[2][nodeID] > 0) { + // Move to default child + nodeID = child_nodeIDs[2][nodeID]; + } else { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } } else { - // Move to right child - nodeID = child_nodeIDs[1][nodeID]; + if (value <= split_values[nodeID]) { + // Move to left child + nodeID = child_nodeIDs[0][nodeID]; + } else { + // 
Move to right child + nodeID = child_nodeIDs[1][nodeID]; + } } + } else { size_t factorID = floor(value) - 1; size_t splitID = floor(split_values[nodeID]); @@ -309,6 +321,8 @@ bool Tree::splitNode(size_t nodeID) { std::vector possible_split_varIDs; createPossibleSplitVarSubset(possible_split_varIDs); + nan_go_right = false; + // Call subclass method, sets split_varIDs and split_values bool stop = splitNodeInternal(nodeID, possible_split_varIDs); if (stop) { @@ -339,13 +353,27 @@ bool Tree::splitNode(size_t nodeID) { size_t pos = start_pos[nodeID]; while (pos < start_pos[right_child_nodeID]) { size_t sampleID = sampleIDs[pos]; - if (data->get_x(sampleID, split_varID) <= split_value) { - // If going to left, do nothing - ++pos; + + if (std::isnan(data->get_x(sampleID, split_varID))) { + if (nan_go_right) { + // If going to right, move to right end + --start_pos[right_child_nodeID]; + std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]); + child_nodeIDs[2][nodeID] = right_child_nodeID; + } else { + // If going to left, do nothing + ++pos; + child_nodeIDs[2][nodeID] = left_child_nodeID; + } } else { - // If going to right, move to right end - --start_pos[right_child_nodeID]; - std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]); + if (data->get_x(sampleID, split_varID) <= split_value) { + // If going to left, do nothing + ++pos; + } else { + // If going to right, move to right end + --start_pos[right_child_nodeID]; + std::swap(sampleIDs[pos], sampleIDs[start_pos[right_child_nodeID]]); + } } } } else { @@ -382,6 +410,7 @@ void Tree::createEmptyNode() { split_values.push_back(0); child_nodeIDs[0].push_back(0); child_nodeIDs[1].push_back(0); + child_nodeIDs[2].push_back(0); start_pos.push_back(0); end_pos.push_back(0); diff --git a/src/Tree.h b/src/Tree.h index 3acbfa20f..c95d13428 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -179,7 +179,7 @@ class Tree { // For terminal nodes the prediction value is saved here std::vector split_values; - // 
Vector of left and right child node IDs, 0 for no child + // Vector of left and right child node IDs, 0 for no child, third value for default child std::vector> child_nodeIDs; // All sampleIDs in the tree, will be re-ordered while splitting @@ -230,6 +230,9 @@ class Tree { uint max_depth; uint depth; size_t last_left_nodeID; + + // Should NaNs go to right child for the current split? + bool nan_go_right; }; } // namespace ranger diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index 7353f47ab..3631a0543 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -238,15 +238,37 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, std::vector& counter) { + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + size_t num_samples_node_nan = 0; - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - uint sample_classID = (*response_classIDs)[sampleID]; - size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), - data->get_x(sampleID, varID)) - possible_split_values.begin(); - - ++counter_per_class[idx * num_classes + sample_classID]; - ++counter[idx]; + size_t last_index = possible_split_values.size() - 1; + if (std::isnan(possible_split_values[last_index])) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + + if (std::isnan(data->get_x(sampleID, varID))) { + ++num_samples_node_nan; + ++class_counts_nan[sample_classID]; + } else { + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + 
++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } } size_t n_left = 0; @@ -254,7 +276,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Compute decrease of impurity for each split for (size_t i = 0; i < possible_split_values.size() - 1; ++i) { - + // Stop if nothing here if (counter[i] == 0) { continue; @@ -263,7 +285,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s n_left += counter[i]; // Stop if right child empty - size_t n_right = num_samples_node - n_left; + size_t n_right = num_samples_node - num_samples_node_nan - n_left; if (n_right == 0) { break; } @@ -274,6 +296,8 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s } double decrease; + double decrease_nanleft; + double decrease_nanright; if (splitrule == HELLINGER) { for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; @@ -292,16 +316,23 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Sum of squares double sum_left = 0; double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; - size_t class_count_right = class_counts[j] - class_counts_left[j]; + size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; sum_right += 
(*class_weights)[j] * class_count_right * class_count_right; + + sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); } // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; } // Regularization @@ -313,6 +344,12 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; best_varID = varID; best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } // Use smaller value if average is numerically the same as the larger value if (best_value == possible_split_values[i + 1]) { @@ -330,15 +367,36 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s size_t num_unique = data->getNumUniqueDataValues(varID); std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); std::fill_n(counter.begin(), num_unique, 0); + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + size_t num_samples_node_nan = 0; // Count values - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - size_t index = data->getIndex(sampleID, varID); - size_t classID = (*response_classIDs)[sampleID]; - - ++counter[index]; - ++counter_per_class[index * num_classes + classID]; + size_t last_index = data->getNumUniqueDataValues(varID) - 1; + if (std::isnan(data->getUniqueDataValue(varID, last_index))) { + for (size_t pos = start_pos[nodeID]; pos < 
end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + if (index < last_index) { + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } else { + ++num_samples_node_nan; + ++class_counts_nan[classID]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } } size_t n_left = 0; @@ -355,7 +413,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s n_left += counter[i]; // Stop if right child empty - size_t n_right = num_samples_node - n_left; + size_t n_right = num_samples_node - num_samples_node_nan - n_left; if (n_right == 0) { break; } @@ -366,6 +424,8 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s } double decrease; + double decrease_nanleft; + double decrease_nanright; if (splitrule == HELLINGER) { for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; @@ -384,21 +444,28 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s // Sum of squares double sum_left = 0; double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; - size_t class_count_right = class_counts[j] - class_counts_left[j]; + size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; sum_right += (*class_weights)[j] * class_count_right * class_count_right; + + sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + 
class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); } // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; } // Regularization regularize(decrease, varID); - + // If better than before, use this if (decrease > best_decrease) { // Find next value in this node @@ -411,7 +478,13 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; best_varID = varID; best_decrease = decrease; - + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + // Use smaller value if average is numerically the same as the larger value if (best_value == data->getUniqueDataValue(varID, j)) { best_value = data->getUniqueDataValue(varID, i); diff --git a/src/utility.h b/src/utility.h index 1460c3023..acca30be2 100644 --- a/src/utility.h +++ b/src/utility.h @@ -33,6 +33,22 @@ namespace ranger { +/** + * Returns whether first value (a) is less than second value (b). NaN treated as Inf. + * @param a First value to compare + * @param b Second value to compare + */ +template +inline bool less_nan(const T& a, const T& b) { + if (std::isnan(a)) { + return false; + } else if (std::isnan(b)) { + return true; + } else { + return a < b; + } +} + /** * Split sequence start..end in num_parts parts with sizes as equal as possible. * @param result Result vector of size num_parts+1. Ranges for the parts are then result[0]..result[1]-1, result[1]..result[2]-1, .. 
diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R index f6059ab14..714da0696 100644 --- a/tests/testthat/test_predict.R +++ b/tests/testthat/test_predict.R @@ -38,13 +38,13 @@ test_that("Prediction works correctly if dependent variable is not first or last expect_gte(mean(predictions(predict(rf, dat[, -3])) == dat$Species), 0.9) }) -test_that("Missing value columns detected in predict", { +test_that("Prediction works with missing values", { rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE) dat <- iris dat[4, 4] <- NA dat[25, 1] <- NA - expect_error(predict(rf, dat), "Missing data in columns: Sepal.Length, Petal.Width.") + expect_silent(predict(rf, dat)) }) test_that("If num.trees set, these number is used for predictions", { diff --git a/tests/testthat/test_ranger.R b/tests/testthat/test_ranger.R index 0f2c0118a..4671e3abe 100644 --- a/tests/testthat/test_ranger.R +++ b/tests/testthat/test_ranger.R @@ -164,10 +164,10 @@ test_that("OOB error is correct for 1 tree, regression", { expect_equal(rf$prediction.error, mean((dat$y - rf$predictions)^2, na.rm = TRUE)) }) -test_that("Missing value columns detected in training", { +test_that("Training works with missing values in x but not in y", { dat <- iris dat[25, 1] <- NA - expect_error(ranger(Species ~ ., dat, num.trees = 5), "Missing data in columns: Sepal.Length") + expect_silent(ranger(Species ~ ., dat, num.trees = 5)) dat <- iris dat[4, 5] <- NA From a457fc6d3844a259122fc7968df9770d88622400 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 1 Jul 2024 16:41:10 +0200 Subject: [PATCH 02/11] only use NA code if any NAs --- R/RcppExports.R | 4 ++-- R/predict.R | 9 ++++++++- R/ranger.R | 20 +++++++++++--------- src/Data.cpp | 36 +++++++++++++++++++++++++----------- src/Data.h | 7 +++++++ src/DataRcpp.h | 3 ++- src/DataSparse.cpp | 3 ++- src/DataSparse.h | 2 +- src/RcppExports.cpp | 9 +++++---- src/Tree.cpp | 19 +++++++++++++------ src/rangerCpp.cpp | 6 +++--- 11 files 
changed, 79 insertions(+), 39 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index dfac21ec6..7a024b68b 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) { - .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) +rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, 
verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, any_na) { + .Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, any_na) } numSmaller <- function(values, reference) { diff --git a/R/predict.R b/R/predict.R index e22668266..0d047a8b2 100644 --- a/R/predict.R +++ b/R/predict.R @@ -185,6 +185,13 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, if (!is.matrix(x) & !inherits(x, "Matrix")) { x <- data.matrix(x) } + + ## Missing values + if (anyNA(x)) { + any.na <- TRUE + } else { + any.na <- FALSE + } ## Num threads ## Default 0 -> detect from system in C++. 
@@ -273,7 +280,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, prediction.type, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, regularization.factor, use.regularization.factor, regularization.usedepth, - node.stats, time.interest, use.time.interest) + node.stats, time.interest, use.time.interest, any.na) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/R/ranger.R b/R/ranger.R index b2ef21c65..41135560b 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -321,22 +321,25 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } ## Check missing values + any.na <- FALSE if (na.action == "na.fail") { - if (any(is.na(x))) { + if (anyNA(x)) { offending_columns <- colnames(x)[colSums(is.na(x)) > 0] - stop("Missing data in columns: ", + stop("Error: Missing data in columns: ", paste0(offending_columns, collapse = ", "), ".", call. = FALSE) } - if (any(is.na(y))) { - stop("Missing data in dependent variable.", call. = FALSE) + if (anyNA(y)) { + stop("Error: Missing data in dependent variable.", call. = FALSE) } } else if (na.action == "na.omit") { # TODO: Implement na.omit stop("na.omit not implemented yet.") } else if (na.action == "na.learn") { - # TODO: fix y - if (any(is.na(y))) { - stop("Missing data in dependent variable.", call. = FALSE) + if (anyNA(y)) { + stop("Error: Missing data in dependent variable.", call. = FALSE) + } + if (anyNA(x)) { + any.na <- TRUE } } else { stop("Error: Invalid value for na.action. 
Use 'na.learn', 'na.omit' or 'na.fail'.") @@ -449,7 +452,6 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, levels.missing <- setdiff(levels(xx), levels.ordered) levels.ordered <- c(levels.missing, levels.ordered) } else if (is.factor(y) & nlevels(y) > 2) { - # TODO: Fix missings here levels.ordered <- pca.order(y = y, x = xx) } else { ## Order factor levels by mean response @@ -985,7 +987,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth, inbag, use.inbag, regularization.factor, use.regularization.factor, regularization.usedepth, - node.stats, time.interest, use.time.interest) + node.stats, time.interest, use.time.interest, any.na) if (length(result) == 0) { stop("User interrupt or internal error.") diff --git a/src/Data.cpp b/src/Data.cpp index 52a5d091c..5c16e9364 100644 --- a/src/Data.cpp +++ b/src/Data.cpp @@ -22,7 +22,7 @@ namespace ranger { Data::Data() : num_rows(0), num_rows_rounded(0), num_cols(0), snp_data(0), num_cols_no_snp(0), externalData(true), index_data(0), max_num_unique_values( - 0), order_snps(false) { + 0), order_snps(false), any_na(false) { } size_t Data::getVariableID(const std::string& variable_name) const { @@ -214,7 +214,6 @@ bool Data::loadFromFileOther(std::ifstream& input_file, std::string header_line, } // #nocov end -// TODO: Check if slower void Data::getAllValues(std::vector& all_values, std::vector& sampleIDs, size_t varID, size_t start, size_t end) const { @@ -225,12 +224,18 @@ void Data::getAllValues(std::vector& all_values, std::vector& sa for (size_t pos = start; pos < end; ++pos) { all_values.push_back(get_x(sampleIDs[pos], varID)); } - std::sort(all_values.begin(), all_values.end(), less_nan); + if (any_na) { + std::sort(all_values.begin(), all_values.end(), less_nan); + } else { + std::sort(all_values.begin(), all_values.end()); + } all_values.erase(std::unique(all_values.begin(), 
all_values.end()), all_values.end()); // Keep only one NaN value - while (all_values.size() >= 2 && std::isnan(all_values[all_values.size() - 2])) { - all_values.pop_back(); + if (any_na) { + while (all_values.size() >= 2 && std::isnan(all_values[all_values.size() - 2])) { + all_values.pop_back(); + } } } else { // If GWA data just use 0, 1, 2 @@ -269,20 +274,29 @@ void Data::sort() { unique_values[row] = get_x(row, col); } - // TODO: Check if this makes it slower if no NA. If yes, check in the beginning if there is any NA and overload a function based on that? Or just use Inf from the beginning? - std::sort(unique_values.begin(), unique_values.end(), less_nan); + if (any_na) { + std::sort(unique_values.begin(), unique_values.end(), less_nan); + } else { + std::sort(unique_values.begin(), unique_values.end()); + } unique_values.erase(unique(unique_values.begin(), unique_values.end()), unique_values.end()); // Get index of unique value for (size_t row = 0; row < num_rows; ++row) { - size_t idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col), less_nan) - - unique_values.begin(); + size_t idx; + if (any_na) { + idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col), less_nan) - unique_values.begin(); + } else { + idx = std::lower_bound(unique_values.begin(), unique_values.end(), get_x(row, col)) - unique_values.begin(); + } index_data[col * num_rows + row] = idx; } // Save unique values (keep NaN) - while (unique_values.size() >= 2 && std::isnan(unique_values[unique_values.size() - 2])) { - unique_values.pop_back(); + if (any_na) { + while (unique_values.size() >= 2 && std::isnan(unique_values[unique_values.size() - 2])) { + unique_values.pop_back(); + } } unique_data_values.push_back(unique_values); if (unique_values.size() > max_num_unique_values) { diff --git a/src/Data.h b/src/Data.h index c58e5ec66..cf6c64656 100644 --- a/src/Data.h +++ b/src/Data.h @@ -185,6 +185,10 @@ class Data { } return varID; } + + 
const bool hasNA() const { + return any_na; + } // #nocov start (cannot be tested anymore because GenABEL not on CRAN) const std::vector>& getSnpOrder() const { @@ -221,6 +225,9 @@ class Data { // Order of 0/1/2 for ordered splitting std::vector> snp_order; bool order_snps; + + // Any missing values? + bool any_na; }; } // namespace ranger diff --git a/src/DataRcpp.h b/src/DataRcpp.h index ca21561cc..59c2b8661 100644 --- a/src/DataRcpp.h +++ b/src/DataRcpp.h @@ -39,13 +39,14 @@ namespace ranger { class DataRcpp: public Data { public: DataRcpp() = default; - DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols) { + DataRcpp(Rcpp::NumericMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, size_t num_cols, bool any_na) { this->x = x; this->y = y; this->variable_names = variable_names; this->num_rows = num_rows; this->num_cols = num_cols; this->num_cols_no_snp = num_cols; + this->any_na = any_na; } DataRcpp(const DataRcpp&) = delete; diff --git a/src/DataSparse.cpp b/src/DataSparse.cpp index 779a54d6b..c240d3e6e 100644 --- a/src/DataSparse.cpp +++ b/src/DataSparse.cpp @@ -31,7 +31,7 @@ namespace ranger { DataSparse::DataSparse(Eigen::SparseMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, - size_t num_cols) : + size_t num_cols, bool any_na) : x { }{ this->x.swap(x); this->y = y; @@ -39,6 +39,7 @@ DataSparse::DataSparse(Eigen::SparseMatrix& x, Rcpp::NumericMatrix& y, s this->num_rows = num_rows; this->num_cols = num_cols; this->num_cols_no_snp = num_cols; + this->any_na = any_na; } } // namespace ranger diff --git a/src/DataSparse.h b/src/DataSparse.h index 3cd904339..f52b017fb 100644 --- a/src/DataSparse.h +++ b/src/DataSparse.h @@ -41,7 +41,7 @@ class DataSparse: public Data { DataSparse() = default; DataSparse(Eigen::SparseMatrix& x, Rcpp::NumericMatrix& y, std::vector variable_names, size_t num_rows, - size_t num_cols); + size_t num_cols, bool 
any_na); DataSparse(const DataSparse&) = delete; DataSparse& operator=(const DataSparse&) = delete; diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 0c9ac2bfc..2f4597f96 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -13,8 +13,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // rangerCpp -Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, bool node_stats, std::vector& time_interest, bool use_time_interest); -RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP 
prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP node_statsSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP) { +Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericMatrix& input_y, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, std::vector& min_node_size, std::vector& min_bucket, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix snp_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, std::vector& class_weights, bool predict_all, bool keep_inbag, std::vector& sample_fraction, double alpha, double minprop, bool holdout, uint prediction_type_r, uint num_random_splits, Eigen::SparseMatrix& sparse_x, bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, bool node_stats, std::vector& time_interest, bool 
use_time_interest, bool any_na); +RcppExport SEXP _ranger_rangerCpp(SEXP treetypeSEXP, SEXP input_xSEXP, SEXP input_ySEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP min_bucketSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP snp_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP class_weightsSEXP, SEXP predict_allSEXP, SEXP keep_inbagSEXP, SEXP sample_fractionSEXP, SEXP alphaSEXP, SEXP minpropSEXP, SEXP holdoutSEXP, SEXP prediction_type_rSEXP, SEXP num_random_splitsSEXP, SEXP sparse_xSEXP, SEXP use_sparse_dataSEXP, SEXP order_snpsSEXP, SEXP oob_errorSEXP, SEXP max_depthSEXP, SEXP inbagSEXP, SEXP use_inbagSEXP, SEXP regularization_factorSEXP, SEXP use_regularization_factorSEXP, SEXP regularization_usedepthSEXP, SEXP node_statsSEXP, SEXP time_interestSEXP, SEXP use_time_interestSEXP, SEXP any_naSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -68,7 +68,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< bool >::type node_stats(node_statsSEXP); Rcpp::traits::input_parameter< std::vector& >::type time_interest(time_interestSEXP); Rcpp::traits::input_parameter< bool >::type use_time_interest(use_time_interestSEXP); - rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, 
snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)); + Rcpp::traits::input_parameter< bool >::type any_na(any_naSEXP); + rcpp_result_gen = Rcpp::wrap(rangerCpp(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest, any_na)); return rcpp_result_gen; END_RCPP } @@ -147,7 +148,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 50}, + {"_ranger_rangerCpp", (DL_FUNC) &_ranger_rangerCpp, 51}, {"_ranger_numSmaller", (DL_FUNC) &_ranger_numSmaller, 2}, {"_ranger_randomObsNode", (DL_FUNC) &_ranger_randomObsNode, 3}, {"_ranger_hshrink_regr", (DL_FUNC) &_ranger_hshrink_regr, 10}, diff --git a/src/Tree.cpp b/src/Tree.cpp index 99e5d2277..11e10e9f3 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -54,7 +54,9 @@ void Tree::init(const Data* data, uint mtry, size_t num_samples, uint seed, 
std: // Create root node, assign bootstrap sample and oob samples child_nodeIDs.push_back(std::vector()); child_nodeIDs.push_back(std::vector()); - child_nodeIDs.push_back(std::vector()); + if (data->hasNA()) { + child_nodeIDs.push_back(std::vector()); + } createEmptyNode(); // Initialize random number generator and set seed @@ -149,7 +151,7 @@ void Tree::grow(std::vector* variable_importance) { } void Tree::predict(const Data* prediction_data, bool oob_prediction) { - + size_t num_samples_predict; if (oob_prediction) { num_samples_predict = num_samples_oob; @@ -159,6 +161,8 @@ void Tree::predict(const Data* prediction_data, bool oob_prediction) { prediction_terminal_nodeIDs.resize(num_samples_predict, 0); + bool any_na = prediction_data->hasNA(); + // For each sample start in root, drop down the tree and return final value for (size_t i = 0; i < num_samples_predict; ++i) { size_t sample_idx; @@ -180,8 +184,8 @@ void Tree::predict(const Data* prediction_data, bool oob_prediction) { double value = prediction_data->get_x(sample_idx, split_varID); if (prediction_data->isOrderedVariable(split_varID)) { - if (std::isnan(value)) { - if (child_nodeIDs[2][nodeID] > 0) { + if (any_na && std::isnan(value)) { + if (child_nodeIDs.size() >= 3 && child_nodeIDs[2][nodeID] > 0) { // Move to default child nodeID = child_nodeIDs[2][nodeID]; } else { @@ -324,6 +328,7 @@ bool Tree::splitNode(size_t nodeID) { std::vector possible_split_varIDs; createPossibleSplitVarSubset(possible_split_varIDs); + bool any_na = data->hasNA(); nan_go_right = false; // Call subclass method, sets split_varIDs and split_values @@ -357,7 +362,7 @@ bool Tree::splitNode(size_t nodeID) { while (pos < start_pos[right_child_nodeID]) { size_t sampleID = sampleIDs[pos]; - if (std::isnan(data->get_x(sampleID, split_varID))) { + if (any_na && std::isnan(data->get_x(sampleID, split_varID))) { if (nan_go_right) { // If going to right, move to right end --start_pos[right_child_nodeID]; @@ -413,7 +418,9 @@ void 
Tree::createEmptyNode() { split_values.push_back(0); child_nodeIDs[0].push_back(0); child_nodeIDs[1].push_back(0); - child_nodeIDs[2].push_back(0); + if (data->hasNA()) { + child_nodeIDs[2].push_back(0); + } start_pos.push_back(0); end_pos.push_back(0); diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index 66e34c0e5..0331c8c4f 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -62,7 +62,7 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM bool use_sparse_data, bool order_snps, bool oob_error, uint max_depth, std::vector>& inbag, bool use_inbag, std::vector& regularization_factor, bool use_regularization_factor, bool regularization_usedepth, - bool node_stats, std::vector& time_interest, bool use_time_interest) { + bool node_stats, std::vector& time_interest, bool use_time_interest, bool any_na) { Rcpp::List result; @@ -112,9 +112,9 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM // Initialize data if (use_sparse_data) { - data = std::make_unique(sparse_x, input_y, variable_names, num_rows, num_cols); + data = std::make_unique(sparse_x, input_y, variable_names, num_rows, num_cols, any_na); } else { - data = std::make_unique(input_x, input_y, variable_names, num_rows, num_cols); + data = std::make_unique(input_x, input_y, variable_names, num_rows, num_cols, any_na); } // If there is snp data, add it From db28116dfdcd8f2fa319bd47ff9e6d664bb07d02 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Mon, 1 Jul 2024 16:43:38 +0200 Subject: [PATCH 03/11] tests for missing data --- tests/testthat/test_missings.R | 65 ++++++++++++++++++++++++++++++++++ tests/testthat/test_predict.R | 9 ----- tests/testthat/test_ranger.R | 24 +------------ 3 files changed, 66 insertions(+), 32 deletions(-) create mode 100644 tests/testthat/test_missings.R diff --git a/tests/testthat/test_missings.R b/tests/testthat/test_missings.R new file mode 100644 index 000000000..445e39e5f --- /dev/null +++ 
b/tests/testthat/test_missings.R @@ -0,0 +1,65 @@ +library(ranger) +library(survival) +context("ranger_unordered") + +test_that("Third child for missings only there if missings in data", { + rf1 <- ranger(Species ~ ., iris, num.trees = 5) + expect_length(rf1$forest$child.nodeIDs[[1]], 2) + + dat <- iris + dat[1, 1] <- NA + rf2 <- ranger(Species ~ ., dat, num.trees = 5) + expect_length(rf2$forest$child.nodeIDs[[1]], 3) +}) + +test_that("Training works with missing values in x but not in y", { + dat <- iris + dat[25, 1] <- NA + expect_silent(ranger(Species ~ ., dat, num.trees = 5)) + + dat <- iris + dat[4, 5] <- NA + expect_error(ranger(Species ~ ., dat, num.trees = 5), "Missing data in dependent variable.") +}) + +test_that("No error if missing value in irrelevant column, training", { + dat <- iris + dat[1, "Sepal.Width"] <- NA + expect_silent(ranger(Species ~ Sepal.Length, dat, num.trees = 5)) +}) + +test_that("No error if missing value in irrelevant column, prediction", { + rf <- ranger(Species ~ Sepal.Length, iris, num.trees = 5) + dat <- iris + dat[1, "Sepal.Width"] <- NA + expect_silent(predict(rf, dat)) +}) + +test_that("Prediction works with missing values", { + rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE) + + dat <- iris + dat[4, 4] <- NA + dat[25, 1] <- NA + expect_silent(predict(rf, dat)) +}) + +test_that("Order splitting working with missing values for classification", { + n <- 20 + dt <- data.frame(x = sample(c("A", "B", "C", "D", NA), n, replace = TRUE), + y = factor(rbinom(n, 1, 0.5)), + stringsAsFactors = FALSE) + + rf <- ranger(y ~ ., data = dt, num.trees = 5, min.node.size = n/2, respect.unordered.factors = 'order') + expect_true(all(rf$forest$is.ordered)) +}) + +test_that("Order splitting working with missing values for multiclass classification", { + n <- 20 + dt <- data.frame(x = sample(c("A", "B", "C", "D", NA), n, replace = TRUE), + y = factor(sample(c("A", "B", "C", "D"), n, replace = TRUE)), + stringsAsFactors = FALSE) 
+ + rf <- ranger(y ~ ., data = dt, num.trees = 5, min.node.size = n/2, respect.unordered.factors = 'order') + expect_true(all(rf$forest$is.ordered)) +}) diff --git a/tests/testthat/test_predict.R b/tests/testthat/test_predict.R index 714da0696..5a60e89ea 100644 --- a/tests/testthat/test_predict.R +++ b/tests/testthat/test_predict.R @@ -38,15 +38,6 @@ test_that("Prediction works correctly if dependent variable is not first or last expect_gte(mean(predictions(predict(rf, dat[, -3])) == dat$Species), 0.9) }) -test_that("Prediction works with missing values", { - rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE) - - dat <- iris - dat[4, 4] <- NA - dat[25, 1] <- NA - expect_silent(predict(rf, dat)) -}) - test_that("If num.trees set, these number is used for predictions", { rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE) pred <- predict(rf, iris, predict.all = TRUE, num.trees = 3) diff --git a/tests/testthat/test_ranger.R b/tests/testthat/test_ranger.R index 232bc11ea..73f7a63ae 100644 --- a/tests/testthat/test_ranger.R +++ b/tests/testthat/test_ranger.R @@ -180,29 +180,6 @@ test_that("OOB error is correct for 1 tree, regression", { expect_equal(rf$prediction.error, mean((dat$y - rf$predictions)^2, na.rm = TRUE)) }) -test_that("Training works with missing values in x but not in y", { - dat <- iris - dat[25, 1] <- NA - expect_silent(ranger(Species ~ ., dat, num.trees = 5)) - - dat <- iris - dat[4, 5] <- NA - expect_error(ranger(Species ~ ., dat, num.trees = 5), "Missing data in dependent variable.") -}) - -test_that("No error if missing value in irrelevant column, training", { - dat <- iris - dat[1, "Sepal.Width"] <- NA - expect_silent(ranger(Species ~ Sepal.Length, dat, num.trees = 5)) -}) - -test_that("No error if missing value in irrelevant column, prediction", { - rf <- ranger(Species ~ Sepal.Length, iris, num.trees = 5) - dat <- iris - dat[1, "Sepal.Width"] <- NA - expect_silent(predict(rf, dat)) -}) - test_that("Split points are 
at (A+B)/2 for numeric features, regression variance splitting", { dat <- data.frame(y = rbinom(100, 1, .5), x = rbinom(100, 1, .5)) rf <- ranger(y ~ x, dat, num.trees = 10) @@ -480,3 +457,4 @@ test_that("Vector min.bucket creates nodes of correct size", { }) + From 1d1bb34b105510441ffcb353327ad022b2c50d07 Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Tue, 2 Jul 2024 11:10:59 +0200 Subject: [PATCH 04/11] put NA splitting in seperate functions --- src/TreeClassification.cpp | 417 ++++++++++++++++++++++++++++--------- src/TreeClassification.h | 11 + 2 files changed, 332 insertions(+), 96 deletions(-) diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index c4b5437d4..0b3511269 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -204,11 +204,22 @@ bool TreeClassification::findBestSplit(size_t nodeID, std::vector& possi // Use faster method for both cases double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); if (q < Q_THRESHOLD) { - findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); + if (data->hasNA()) { + findBestSplitValueNanSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } + } else { - findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); + if (data->hasNA()) { + findBestSplitValueNanLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } } } } else { @@ -273,36 +284,15 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t 
varID, s double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, std::vector& counter) { - // Counters without NaNs - std::vector class_counts_nan(num_classes, 0); - size_t num_samples_node_nan = 0; - size_t last_index = possible_split_values.size() - 1; - if (std::isnan(possible_split_values[last_index])) { - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - uint sample_classID = (*response_classIDs)[sampleID]; - - if (std::isnan(data->get_x(sampleID, varID))) { - ++num_samples_node_nan; - ++class_counts_nan[sample_classID]; - } else { - size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), - data->get_x(sampleID, varID)) - possible_split_values.begin(); - ++counter_per_class[idx * num_classes + sample_classID]; - ++counter[idx]; - } - } - } else { - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - uint sample_classID = (*response_classIDs)[sampleID]; - size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), - data->get_x(sampleID, varID)) - possible_split_values.begin(); - - ++counter_per_class[idx * num_classes + sample_classID]; - ++counter[idx]; - } + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; } size_t n_left = 0; @@ -310,7 +300,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Compute decrease of impurity for each split for (size_t i = 0; i < possible_split_values.size() - 1; ++i) { - + // Stop if nothing here if (counter[i] == 0) { continue; @@ -319,7 
+309,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s n_left += counter[i]; // Stop if right child empty - size_t n_right = num_samples_node - num_samples_node_nan - n_left; + size_t n_right = num_samples_node - n_left; if (n_right == 0) { break; } @@ -330,8 +320,6 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s } double decrease; - double decrease_nanleft; - double decrease_nanright; if (splitrule == HELLINGER) { for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; @@ -350,23 +338,16 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s // Sum of squares double sum_left = 0; double sum_right = 0; - double sum_left_withnan = 0; - double sum_right_withnan = 0; for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; - size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + size_t class_count_right = class_counts[j] - class_counts_left[j]; sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; sum_right += (*class_weights)[j] * class_count_right * class_count_right; - - sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); - sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); } // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; - decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); - decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; } // Stop if class-wise minimal bucket size reached @@ -394,12 +375,6 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s 
best_varID = varID; best_decrease = decrease; - if (decrease_nanright > decrease_nanleft) { - nan_go_right = true; - } else { - nan_go_right = false; - } - // Use smaller value if average is numerically the same as the larger value if (best_value == possible_split_values[i + 1]) { best_value = possible_split_values[i]; @@ -416,36 +391,16 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s size_t num_unique = data->getNumUniqueDataValues(varID); std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); std::fill_n(counter.begin(), num_unique, 0); - - // Counters without NaNs - std::vector class_counts_nan(num_classes, 0); - size_t num_samples_node_nan = 0; // Count values size_t last_index = data->getNumUniqueDataValues(varID) - 1; - if (std::isnan(data->getUniqueDataValue(varID, last_index))) { - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - size_t index = data->getIndex(sampleID, varID); - size_t classID = (*response_classIDs)[sampleID]; - - if (index < last_index) { - ++counter[index]; - ++counter_per_class[index * num_classes + classID]; - } else { - ++num_samples_node_nan; - ++class_counts_nan[classID]; - } - } - } else { - for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { - size_t sampleID = sampleIDs[pos]; - size_t index = data->getIndex(sampleID, varID); - size_t classID = (*response_classIDs)[sampleID]; - - ++counter[index]; - ++counter_per_class[index * num_classes + classID]; - } + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; } size_t n_left = 0; @@ -462,7 +417,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s n_left += counter[i]; // Stop if right child empty - 
size_t n_right = num_samples_node - num_samples_node_nan - n_left; + size_t n_right = num_samples_node - n_left; if (n_right == 0) { break; } @@ -473,8 +428,6 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s } double decrease; - double decrease_nanleft; - double decrease_nanright; if (splitrule == HELLINGER) { for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; @@ -493,23 +446,16 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s // Sum of squares double sum_left = 0; double sum_right = 0; - double sum_left_withnan = 0; - double sum_right_withnan = 0; for (size_t j = 0; j < num_classes; ++j) { class_counts_left[j] += counter_per_class[i * num_classes + j]; - size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + size_t class_count_right = class_counts[j] - class_counts_left[j]; sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; sum_right += (*class_weights)[j] * class_count_right * class_count_right; - - sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); - sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); } // Decrease of impurity decrease = sum_right / (double) n_right + sum_left / (double) n_left; - decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); - decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; } // Stop if class-wise minimal bucket size reached @@ -542,13 +488,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; best_varID = varID; best_decrease = decrease; - - if 
(decrease_nanright > decrease_nanleft) { - nan_go_right = true; - } else { - nan_go_right = false; - } - + // Use smaller value if average is numerically the same as the larger value if (best_value == data->getUniqueDataValue(varID, j)) { best_value = data->getUniqueDataValue(varID, i); @@ -986,6 +926,291 @@ void TreeClassification::findBestSplitValueExtraTreesUnordered(size_t nodeID, si } } +void TreeClassification::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * num_classes), n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeClassification::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, + std::vector& counter) { + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + size_t 
num_samples_node_nan = 0; + + size_t last_index = possible_split_values.size() - 1; + if (std::isnan(possible_split_values[last_index])) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + + if (std::isnan(data->get_x(sampleID, varID))) { + ++num_samples_node_nan; + ++class_counts_nan[sample_classID]; + } else { + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } + } + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < possible_split_values.size() - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { + continue; + } + + double decrease; + double decrease_nanleft; + double decrease_nanright; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; + for (size_t j = 0; j < num_classes; ++j) { + class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t 
class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + + sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); + } + + // Decrease of impurity + decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Use mid-point split + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeClassification::findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& 
best_varID, + double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); + std::fill_n(counter.begin(), num_unique, 0); + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + size_t num_samples_node_nan = 0; + + // Count values + size_t last_index = data->getNumUniqueDataValues(varID) - 1; + if (std::isnan(data->getUniqueDataValue(varID, last_index))) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + if (index < last_index) { + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } else { + ++num_samples_node_nan; + ++class_counts_nan[classID]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } + } + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { + continue; + } + + double decrease; + double decrease_nanleft; + double decrease_nanright; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; + for (size_t j = 0; j < num_classes; ++j) { + 
class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + + sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); + } + + // Decrease of impurity + decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); 
+ } + } + } +} + void TreeClassification::addGiniImportance(size_t nodeID, size_t varID, double decrease) { double best_decrease = decrease; diff --git a/src/TreeClassification.h b/src/TreeClassification.h index 534c121a0..82a491ad6 100644 --- a/src/TreeClassification.h +++ b/src/TreeClassification.h @@ -81,6 +81,17 @@ class TreeClassification: public Tree { void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, double& best_decrease); + + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, + std::vector& counter); + void findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); void addGiniImportance(size_t nodeID, size_t varID, double decrease); From 71558cbf4316ec839c3475916b78239e749e7f43 Mon Sep 17 00:00:00 2001 From: "Marvin N. 
Wright" Date: Tue, 2 Jul 2024 11:15:24 +0200 Subject: [PATCH 05/11] NA splitting only for certain split rules --- R/ranger.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/R/ranger.R b/R/ranger.R index 41135560b..5a7c30b83 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -340,6 +340,9 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } if (anyNA(x)) { any.na <- TRUE + if (!(splitrule %in% c("gini", "variance", "logrank"))) { + stop("Error: Missing value handling currently only implemented for gini, variance and logrank splitrules.") + } } } else { stop("Error: Invalid value for na.action. Use 'na.learn', 'na.omit' or 'na.fail'.") From 47ec4c50ef3c9911bd2bace3e2e539385503dc10 Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Tue, 2 Jul 2024 11:29:00 +0200 Subject: [PATCH 06/11] NA splitting only for certain split rules --- R/ranger.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/ranger.R b/R/ranger.R index 5a7c30b83..adde5fcb2 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -340,7 +340,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } if (anyNA(x)) { any.na <- TRUE - if (!(splitrule %in% c("gini", "variance", "logrank"))) { + if (!is.null(splitrule) && !(splitrule %in% c("gini", "variance", "logrank"))) { stop("Error: Missing value handling currently only implemented for gini, variance and logrank splitrules.") } } From 6510c76be833a4ca2ae36ad2f4225e13b1fd2242 Mon Sep 17 00:00:00 2001 From: "Marvin N. 
Wright" Date: Tue, 2 Jul 2024 11:40:44 +0200 Subject: [PATCH 07/11] revert small changes --- src/TreeClassification.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index 0b3511269..a57de083e 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -284,13 +284,12 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, std::vector& counter) { - size_t last_index = possible_split_values.size() - 1; for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; uint sample_classID = (*response_classIDs)[sampleID]; size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), - data->get_x(sampleID, varID)) - possible_split_values.begin(); - + data->get_x(sampleID, varID)) - possible_split_values.begin(); + ++counter_per_class[idx * num_classes + sample_classID]; ++counter[idx]; } @@ -374,7 +373,7 @@ void TreeClassification::findBestSplitValueSmallQ(size_t nodeID, size_t varID, s best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; best_varID = varID; best_decrease = decrease; - + // Use smaller value if average is numerically the same as the larger value if (best_value == possible_split_values[i + 1]) { best_value = possible_split_values[i]; @@ -393,12 +392,11 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s std::fill_n(counter.begin(), num_unique, 0); // Count values - size_t last_index = data->getNumUniqueDataValues(varID) - 1; for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { size_t sampleID = sampleIDs[pos]; size_t index = data->getIndex(sampleID, varID); size_t classID = (*response_classIDs)[sampleID]; - + ++counter[index]; ++counter_per_class[index * num_classes + classID]; } @@ -475,7 
+473,7 @@ void TreeClassification::findBestSplitValueLargeQ(size_t nodeID, size_t varID, s // Regularization regularize(decrease, varID); - + // If better than before, use this if (decrease > best_decrease) { // Find next value in this node From 0ab1efb9b13ffd3d803fa29277bd999e8ca19f47 Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Tue, 2 Jul 2024 13:41:08 +0200 Subject: [PATCH 08/11] add missing value handling for probability and regression --- src/TreeProbability.cpp | 305 +++++++++++++++++++++++++++++++++++++++- src/TreeProbability.h | 11 ++ src/TreeRegression.cpp | 226 ++++++++++++++++++++++++++++- src/TreeRegression.h | 8 ++ 4 files changed, 544 insertions(+), 6 deletions(-) diff --git a/src/TreeProbability.cpp b/src/TreeProbability.cpp index c92a57137..33cee29fd 100644 --- a/src/TreeProbability.cpp +++ b/src/TreeProbability.cpp @@ -208,11 +208,21 @@ bool TreeProbability::findBestSplit(size_t nodeID, std::vector& possible // Use faster method for both cases double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); if (q < Q_THRESHOLD) { - findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); + if (data->hasNA()) { + findBestSplitValueNanSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } } else { - findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, - best_decrease); + if (data->hasNA()) { + findBestSplitValueNanLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease); + } } } } else { @@ -917,6 +927,293 @@ void 
TreeProbability::findBestSplitValueExtraTreesUnordered(size_t nodeID, size_ } } +void TreeProbability::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector class_counts_right(num_splits * num_classes), n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, class_counts_right, n_right); + } else { + std::fill_n(counter_per_class.begin(), num_splits * num_classes, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, + best_decrease, possible_split_values, counter_per_class, counter); + } +} + +void TreeProbability::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, + std::vector& counter) { + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + size_t num_samples_node_nan = 0; + + size_t last_index = possible_split_values.size() - 1; + if (std::isnan(possible_split_values[last_index])) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + + if 
(std::isnan(data->get_x(sampleID, varID))) { + ++num_samples_node_nan; + ++class_counts_nan[sample_classID]; + } else { + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + uint sample_classID = (*response_classIDs)[sampleID]; + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + ++counter_per_class[idx * num_classes + sample_classID]; + ++counter[idx]; + } + } + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < possible_split_values.size() - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { + continue; + } + + double decrease; + double decrease_nanleft; + double decrease_nanright; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; + for (size_t j = 0; j < num_classes; ++j) { + class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + + sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + 
class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); + } + + // Decrease of impurity + decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; + + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Use mid-point split + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeProbability::findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter_per_class.begin(), num_unique * num_classes, 0); + std::fill_n(counter.begin(), num_unique, 0); + + // Counters without NaNs + std::vector class_counts_nan(num_classes, 0); + 
size_t num_samples_node_nan = 0; + + // Count values + size_t last_index = data->getNumUniqueDataValues(varID) - 1; + if (std::isnan(data->getUniqueDataValue(varID, last_index))) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + if (index < last_index) { + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } else { + ++num_samples_node_nan; + ++class_counts_nan[classID]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + size_t classID = (*response_classIDs)[sampleID]; + + ++counter[index]; + ++counter_per_class[index * num_classes + classID]; + } + } + + + size_t n_left = 0; + std::vector class_counts_left(num_classes); + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (min_bucket->size() == 1 && (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0])) { + continue; + } + + double decrease; + double decrease_nanleft; + double decrease_nanright; + + // Sum of squares + double sum_left = 0; + double sum_right = 0; + double sum_left_withnan = 0; + double sum_right_withnan = 0; + for (size_t j = 0; j < num_classes; ++j) { + class_counts_left[j] += counter_per_class[i * num_classes + j]; + size_t class_count_right = class_counts[j] - class_counts_nan[j] - class_counts_left[j]; + + sum_left += (*class_weights)[j] * class_counts_left[j] * class_counts_left[j]; + sum_right += (*class_weights)[j] * class_count_right * class_count_right; + + 
sum_left_withnan += (*class_weights)[j] * (class_counts_left[j] + class_counts_nan[j]) * (class_counts_left[j] + class_counts_nan[j]); + sum_right_withnan += (*class_weights)[j] * (class_count_right + class_counts_nan[j]) * (class_count_right + class_counts_nan[j]); + } + + // Decrease of impurity + decrease = sum_right / (double) n_right + sum_left / (double) n_left; + decrease_nanleft = sum_right / (double) n_right + sum_left_withnan / (double) (n_left + num_samples_node_nan); + decrease_nanright = sum_right_withnan / (double) (n_right + num_samples_node_nan) + sum_left / (double) n_left; + + // Stop if class-wise minimal bucket size reached + if (min_bucket->size() > 1) { + bool stop = false; + for (size_t j = 0; j < num_classes; ++j) { + size_t class_count_right = class_counts[j] - class_counts_left[j]; + if (class_counts_left[j] < (*min_bucket)[j] || class_count_right < (*min_bucket)[j]) { + stop = true; + break; + } + } + if (stop) { + continue; + } + } + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); + } + } + } +} + void TreeProbability::addImpurityImportance(size_t nodeID, size_t varID, double decrease) { double best_decrease = decrease; diff --git a/src/TreeProbability.h b/src/TreeProbability.h index 0bf9d9acf..0b6593c30 100644 --- a/src/TreeProbability.h +++ b/src/TreeProbability.h @@ -87,6 +87,17 @@ class 
TreeProbability: public Tree { void findBestSplitValueExtraTreesUnordered(size_t nodeID, size_t varID, size_t num_classes, const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, double& best_decrease); + + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease, const std::vector& possible_split_values, std::vector& counter_per_class, + std::vector& counter); + void findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, size_t num_classes, + const std::vector& class_counts, size_t num_samples_node, double& best_value, size_t& best_varID, + double& best_decrease); void addImpurityImportance(size_t nodeID, size_t varID, double decrease); diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index ec59528fb..5f62498e2 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -165,9 +165,17 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_ // Use faster method for both cases double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); if (q < Q_THRESHOLD) { - findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + if (data->hasNA()) { + findBestSplitValueNanSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } else { + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } } else { - findBestSplitValueLargeQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + if (data->hasNA()) { + findBestSplitValueNanLargeQ(nodeID, varID, 
sum_node, num_samples_node, best_value, best_varID, best_decrease); + } else { + findBestSplitValueLargeQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); + } } } } else { @@ -962,6 +970,220 @@ void TreeRegression::findBestSplitValueBeta(size_t nodeID, size_t varID, double } } +void TreeRegression::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + + // Create possible split values + std::vector possible_split_values; + data->getAllValues(possible_split_values, sampleIDs, varID, start_pos[nodeID], end_pos[nodeID]); + + // Try next variable if all equal for this + if (possible_split_values.size() < 2) { + return; + } + + const size_t num_splits = possible_split_values.size(); + if (memory_saving_splitting) { + std::vector sums_right(num_splits); + std::vector n_right(num_splits); + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums_right, n_right); + } else { + std::fill_n(sums.begin(), num_splits, 0); + std::fill_n(counter.begin(), num_splits, 0); + findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease, + possible_split_values, sums, counter); + } +} + +void TreeRegression::findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums, std::vector& counter) { + + // Counters without NaNs + double sum_nan = 0; + size_t num_samples_node_nan = 0; + + size_t last_index = possible_split_values.size() - 1; + if (std::isnan(possible_split_values[last_index])) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + + if (std::isnan(data->get_x(sampleID, varID))) { + sum_nan += 
data->get_y(sampleID, 0); + ++num_samples_node_nan; + } else { + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + sums[idx] += data->get_y(sampleID, 0); + ++counter[idx]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t idx = std::lower_bound(possible_split_values.begin(), possible_split_values.end(), + data->get_x(sampleID, varID)) - possible_split_values.begin(); + + sums[idx] += data->get_y(sampleID, 0); + ++counter[idx]; + } + } + + size_t n_left = 0; + double sum_left = 0; + + // Compute decrease of impurity for each split + for (size_t i = 0; i < possible_split_values.size() - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + sum_left += sums[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { + continue; + } + + double sum_right = sum_node - sum_left - sum_nan; + double decrease = sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right; + + double decrease_nanleft = (sum_left + sum_nan) * (sum_left + sum_nan) / (double) (n_left + num_samples_node_nan) + sum_right * sum_right / (double) n_right; + double decrease_nanright = sum_left * sum_left / (double) n_left + (sum_right + sum_nan) * (sum_right + sum_nan) / (double) (n_right + num_samples_node_nan); + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Use mid-point split + best_value = (possible_split_values[i] + possible_split_values[i + 1]) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = 
true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == possible_split_values[i + 1]) { + best_value = possible_split_values[i]; + } + } + } +} + +void TreeRegression::findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease) { + + // Set counters to 0 + size_t num_unique = data->getNumUniqueDataValues(varID); + std::fill_n(counter.begin(), num_unique, 0); + std::fill_n(sums.begin(), num_unique, 0); + + // Counters without NaNs + double sum_nan = 0; + size_t num_samples_node_nan = 0; + + size_t last_index = data->getNumUniqueDataValues(varID) - 1; + if (std::isnan(data->getUniqueDataValue(varID, last_index))) { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + + if (std::isnan(data->get_x(sampleID, varID))) { + sum_nan += data->get_y(sampleID, 0); + ++num_samples_node_nan; + } else { + size_t index = data->getIndex(sampleID, varID); + sums[index] += data->get_y(sampleID, 0); + ++counter[index]; + } + } + } else { + for (size_t pos = start_pos[nodeID]; pos < end_pos[nodeID]; ++pos) { + size_t sampleID = sampleIDs[pos]; + size_t index = data->getIndex(sampleID, varID); + + sums[index] += data->get_y(sampleID, 0); + ++counter[index]; + } + } + + + size_t n_left = 0; + double sum_left = 0; + + // Compute decrease of impurity for each split + for (size_t i = 0; i < num_unique - 1; ++i) { + + // Stop if nothing here + if (counter[i] == 0) { + continue; + } + + n_left += counter[i]; + sum_left += sums[i]; + + // Stop if right child empty + size_t n_right = num_samples_node - num_samples_node_nan - n_left; + if (n_right == 0) { + break; + } + + // Stop if minimal bucket size reached + if (n_left < (*min_bucket)[0] || n_right < (*min_bucket)[0]) { + continue; + } + + double sum_right = sum_node - sum_left; + double decrease 
= sum_left * sum_left / (double) n_left + sum_right * sum_right / (double) n_right; + + double decrease_nanleft = (sum_left + sum_nan) * (sum_left + sum_nan) / (double) (n_left + num_samples_node_nan) + sum_right * sum_right / (double) n_right; + double decrease_nanright = sum_left * sum_left / (double) n_left + (sum_right + sum_nan) * (sum_right + sum_nan) / (double) (n_right + num_samples_node_nan); + + // Regularization + regularize(decrease, varID); + + // If better than before, use this + if (decrease > best_decrease) { + // Find next value in this node + size_t j = i + 1; + while (j < num_unique && counter[j] == 0) { + ++j; + } + + // Use mid-point split + best_value = (data->getUniqueDataValue(varID, i) + data->getUniqueDataValue(varID, j)) / 2; + best_varID = varID; + best_decrease = decrease; + + if (decrease_nanright > decrease_nanleft) { + nan_go_right = true; + } else { + nan_go_right = false; + } + + // Use smaller value if average is numerically the same as the larger value + if (best_value == data->getUniqueDataValue(varID, j)) { + best_value = data->getUniqueDataValue(varID, i); + } + } + } +} + void TreeRegression::addImpurityImportance(size_t nodeID, size_t varID, double decrease) { size_t num_samples_node = end_pos[nodeID] - start_pos[nodeID]; diff --git a/src/TreeRegression.h b/src/TreeRegression.h index 84c224f63..267a21621 100644 --- a/src/TreeRegression.h +++ b/src/TreeRegression.h @@ -82,6 +82,14 @@ class TreeRegression: public Tree { void findBestSplitValueBeta(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, std::vector& sums_right, std::vector& n_right); + + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); + void findBestSplitValueNanSmallQ(size_t nodeID, size_t varID, double sum_node, size_t 
num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease, std::vector possible_split_values, + std::vector& sums, std::vector& counter); + void findBestSplitValueNanLargeQ(size_t nodeID, size_t varID, double sum_node, size_t num_samples_node, + double& best_value, size_t& best_varID, double& best_decrease); void addImpurityImportance(size_t nodeID, size_t varID, double decrease); From e779a654054c7202694f5e03d9a13065932edb45 Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Tue, 2 Jul 2024 13:49:19 +0200 Subject: [PATCH 09/11] more tests for missings --- tests/testthat/test_missings.R | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test_missings.R b/tests/testthat/test_missings.R index 445e39e5f..272b1bd60 100644 --- a/tests/testthat/test_missings.R +++ b/tests/testthat/test_missings.R @@ -16,6 +16,8 @@ test_that("Training works with missing values in x but not in y", { dat <- iris dat[25, 1] <- NA expect_silent(ranger(Species ~ ., dat, num.trees = 5)) + expect_silent(ranger(Petal.Width ~ ., dat, num.trees = 5)) + expect_error(ranger(Sepal.Length ~ ., dat, num.trees = 5), "Missing data in dependent variable.") dat <- iris dat[4, 5] <- NA @@ -35,7 +37,7 @@ test_that("No error if missing value in irrelevant column, prediction", { expect_silent(predict(rf, dat)) }) -test_that("Prediction works with missing values", { +test_that("Prediction works with missing values, classification", { rf <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE) dat <- iris @@ -44,6 +46,15 @@ test_that("Prediction works with missing values", { expect_silent(predict(rf, dat)) }) +test_that("Prediction works with missing values, regression", { + rf <- ranger(Sepal.Width ~ ., iris, num.trees = 5, write.forest = TRUE) + + dat <- iris + dat[4, 4] <- NA + dat[25, 1] <- NA + expect_silent(predict(rf, dat)) +}) + test_that("Order splitting working with missing values for classification", { n <- 20 dt <- 
data.frame(x = sample(c("A", "B", "C", "D", NA), n, replace = TRUE), @@ -63,3 +74,10 @@ test_that("Order splitting working with missing values for multiclass classifica rf <- ranger(y ~ ., data = dt, num.trees = 5, min.node.size = n/2, respect.unordered.factors = 'order') expect_true(all(rf$forest$is.ordered)) }) + +test_that("Missing values for survival not yet working", { + dat <- veteran + dat[1, 1] <- NA + + expect_error(ranger(Surv(time, status) ~ ., dat, num.trees = 5), "Error: Missing value handling not yet implemented for survival forests\\.") +}) From 57a07ebe1505a5a9ad307055e747e545e1dd3a62 Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Tue, 2 Jul 2024 15:17:36 +0200 Subject: [PATCH 10/11] implement na.omit --- R/ranger.R | 20 ++++++++++---- tests/testthat/test_missings.R | 50 ++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/R/ranger.R b/R/ranger.R index adde5fcb2..73cb3a52b 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -332,16 +332,22 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, stop("Error: Missing data in dependent variable.", call. = FALSE) } } else if (na.action == "na.omit") { - # TODO: Implement na.omit - stop("na.omit not implemented yet.") + if (anyNA(x)) { + idx_keep <- stats::complete.cases(x) + x <- x[idx_keep, , drop = FALSE] + y <- y[idx_keep, drop = FALSE] + if (nrow(x) < 1) { + stop("Error: No observations left after removing missing values.") + } + } } else if (na.action == "na.learn") { if (anyNA(y)) { stop("Error: Missing data in dependent variable.", call. 
= FALSE) } if (anyNA(x)) { any.na <- TRUE - if (!is.null(splitrule) && !(splitrule %in% c("gini", "variance", "logrank"))) { - stop("Error: Missing value handling currently only implemented for gini, variance and logrank splitrules.") + if (!is.null(splitrule) && !(splitrule %in% c("gini", "variance"))) { + stop("Error: Missing value handling currently only implemented for gini and variance splitrules.") } } } else { @@ -378,6 +384,11 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, stop("Error: Unsupported type of dependent variable.") } + ## No missing value handling for survival yet + if (any.na & treetype == 5) { + stop("Error: Missing value handling not yet implemented for survival forests.") + } + ## Number of levels if (treetype %in% c(1, 9)) { if (is.factor(y)) { @@ -447,7 +458,6 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Don't order if only one level levels.ordered <- levels(xx) } else if (inherits(y, "Surv")) { - # TODO: Fix missings here ## Use median survival if available or largest quantile available in all strata if median not available levels.ordered <- largest.quantile(y ~ xx) diff --git a/tests/testthat/test_missings.R b/tests/testthat/test_missings.R index 272b1bd60..ebeb15d8f 100644 --- a/tests/testthat/test_missings.R +++ b/tests/testthat/test_missings.R @@ -81,3 +81,53 @@ test_that("Missing values for survival not yet working", { expect_error(ranger(Surv(time, status) ~ ., dat, num.trees = 5), "Error: Missing value handling not yet implemented for survival forests\\.") }) + +test_that("na.omit leads to same result as manual removal, classification", { + dat <- iris + dat[1, 1] <- NA + rf1 <- ranger(Species ~ ., dat, num.trees = 5, seed = 10, na.action = "na.omit") + + dat2 <- na.omit(dat) + rf2 <- ranger(Species ~ ., dat2, num.trees = 5, seed = 10) + + expect_equal(rf1$predictions, rf2$predictions) +}) + +test_that("na.omit leads to same result as manual removal, 
probability", { + dat <- iris + dat[1, 1] <- NA + rf1 <- ranger(Species ~ ., dat, num.trees = 5, probability = TRUE, seed = 10, na.action = "na.omit") + + dat2 <- na.omit(dat) + rf2 <- ranger(Species ~ ., dat2, num.trees = 5, probability = TRUE, seed = 10) + + expect_equal(rf1$predictions, rf2$predictions) +}) + +test_that("na.omit leads to same result as manual removal, regression", { + dat <- iris + dat[1, 1] <- NA + rf1 <- ranger(Sepal.Width ~ ., dat, num.trees = 5, seed = 10, na.action = "na.omit") + + dat2 <- na.omit(dat) + rf2 <- ranger(Sepal.Width ~ ., dat2, num.trees = 5, seed = 10) + + expect_equal(rf1$predictions, rf2$predictions) +}) + +test_that("na.omit leads to same result as manual removal, survival", { + dat <- veteran + dat[1, 1] <- NA + rf1 <- ranger(Surv(time, status) ~ ., dat, num.trees = 5, seed = 10, na.action = "na.omit") + + dat2 <- na.omit(dat) + rf2 <- ranger(Surv(time, status) ~ ., dat2, num.trees = 5, seed = 10) + + expect_equal(rf1$chf, rf2$chf) +}) + +test_that("na.omit not working if no observations left", { + dat <- iris + dat[1:150, 1] <- NA + expect_error(ranger(Species ~ ., dat, num.trees = 5, na.action = "na.omit"), "Error: No observations left after removing missing values\\.") +}) From 71c3c46d377e34580367b7770cce3ad9808496da Mon Sep 17 00:00:00 2001 From: "Marvin N. Wright" Date: Wed, 3 Jul 2024 07:27:03 +0200 Subject: [PATCH 11/11] version and docs --- DESCRIPTION | 6 +++--- NEWS.md | 3 +++ R/ranger.R | 4 ++++ cpp_version/src/version.h | 2 +- man/ranger.Rd | 4 ++++ 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index edb4ca139..7e4837d09 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,8 +1,8 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests -Version: 0.16.2 -Date: 2024-05-16 +Version: 0.16.3 +Date: 2024-07-03 Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb] Maintainer: Marvin N. 
Wright Description: A fast implementation of Random Forests, particularly suited for high @@ -19,7 +19,7 @@ Suggests: survival, testthat Encoding: UTF-8 -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 URL: https://imbs-hl.github.io/ranger/, https://github.com/imbs-hl/ranger BugReports: https://github.com/imbs-hl/ranger/issues diff --git a/NEWS.md b/NEWS.md index 708b2316c..cfb9d3535 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,4 +1,7 @@ +# ranger 0.16.3 +* Add handling of missing values for classification and regression + # ranger 0.16.2 * Add Poisson splitting rule for regression trees diff --git a/R/ranger.R b/R/ranger.R index 97be99cbc..115ea3555 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -85,6 +85,10 @@ ##' If \code{regularization.usedepth=TRUE}, \eqn{f^d} is used, where \emph{f} is the regularization factor and \emph{d} the depth of the node. ##' If regularization is used, multithreading is deactivated because all trees need access to the list of variables that are already included in the model. ##' +##' Missing values can be internally handled by setting \code{na.action = "na.learn"} (default), by omitting observations with missing values with \code{na.action = "na.omit"} or by stopping if missing values are found with \code{na.action = "na.fail"}. +##' With \code{na.action = "na.learn"}, missing values are ignored for calculating an initial split criterion value (i.e., decrease of impurity). Then for the best split, all missing values are tried in both child nodes and the choice is again based on the split criterion value. +##' For prediction, this direction is saved as the "default" direction. If a missing value occurs during prediction at a node where there is no default direction, it goes left. +##' ##' For a large number of variables and data frames as input data the formula interface can be slow or impossible to use. ##' Alternatively \code{dependent.variable.name} (and \code{status.variable.name} for survival) or \code{x} and \code{y} can be used.
##' Use \code{x} and \code{y} with a matrix for \code{x} to avoid conversions and save memory. diff --git a/cpp_version/src/version.h b/cpp_version/src/version.h index 673de5f2a..4aeaffdf8 100644 --- a/cpp_version/src/version.h +++ b/cpp_version/src/version.h @@ -1,3 +1,3 @@ #ifndef RANGER_VERSION -#define RANGER_VERSION "0.16.2" +#define RANGER_VERSION "0.16.3" #endif diff --git a/man/ranger.Rd b/man/ranger.Rd index fead319f5..d8d40f305 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -228,6 +228,10 @@ Regularization works by penalizing new variables by multiplying the splitting cr If \code{regularization.usedepth=TRUE}, \eqn{f^d} is used, where \emph{f} is the regularization factor and \emph{d} the depth of the node. If regularization is used, multithreading is deactivated because all trees need access to the list of variables that are already included in the model. +Missing values can be internally handled by setting \code{na.action = "na.learn"} (default), by omitting observations with missing values with \code{na.action = "na.omit"} or by stopping if missing values are found with \code{na.action = "na.fail"}. +With \code{na.action = "na.learn"}, missing values are ignored for calculating an initial split criterion value (i.e., decrease of impurity). Then for the best split, all missing values are tried in both child nodes and the choice is again based on the split criterion value. +For prediction, this direction is saved as the "default" direction. If a missing value occurs during prediction at a node where there is no default direction, it goes left. + For a large number of variables and data frames as input data the formula interface can be slow or impossible to use. Alternatively \code{dependent.variable.name} (and \code{status.variable.name} for survival) or \code{x} and \code{y} can be used. Use \code{x} and \code{y} with a matrix for \code{x} to avoid conversions and save memory.