diff --git a/R/ranger.R b/R/ranger.R index 9691dde0..e2059282 100644 --- a/R/ranger.R +++ b/R/ranger.R @@ -120,7 +120,7 @@ ##' @param num.threads Number of threads. Default is number of CPUs available. ##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems. ##' @param verbose Show computation status and estimated runtime. -##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction and number of observations for each node. +##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node. ##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed. ##' @param dependent.variable.name Name of dependent variable, needed if no formula given. For survival forests this is the time variable. ##' @param status.variable.name Name of status variable, only applicable to survival data and needed if no formula given. Use 1 for event and 0 for censoring. diff --git a/R/treeInfo.R b/R/treeInfo.R index 07139931..184eabd8 100644 --- a/R/treeInfo.R +++ b/R/treeInfo.R @@ -52,6 +52,8 @@ #' \code{splitval} \tab The splitting value. For numeric or ordinal variables, all values smaller or equal go to the left, larger values to the right. For unordered factor variables see above. \cr #' \code{terminal} \tab Logical, TRUE for terminal nodes. \cr #' \code{prediction} \tab One column with the predicted class (factor) for classification and the predicted numerical value for regression. One probability per class for probability estimation in several columns. Nothing for survival, refer to \code{object$forest$chf} for the CHF node predictions. \cr +#' \code{numSamples} \tab Number of samples in the node (only if ranger called with \code{node.stats = TRUE}). 
\cr +#' \code{splitStat} \tab Split statistics, i.e., value of the splitting criterion (only if ranger called with \code{node.stats = TRUE}). \cr #' } #' @examples #' rf <- ranger(Species ~ ., data = iris) @@ -164,6 +166,10 @@ treeInfo <- function(object, tree = 1) { if (!is.null(forest$num.samples.nodes)) { result$numSamples <- forest$num.samples.nodes[[tree]] } + if (!is.null(forest$split.stats)) { + result$splitStat <- forest$split.stats[[tree]] + result$splitStat[result$terminal] <- NA + } result } diff --git a/man/ranger.Rd b/man/ranger.Rd index e67b6810..da464a1b 100644 --- a/man/ranger.Rd +++ b/man/ranger.Rd @@ -118,7 +118,7 @@ ranger( \item{verbose}{Show computation status and estimated runtime.} -\item{node.stats}{Save node statistics. Set to \code{TRUE} to save prediction and number of observations for each node.} +\item{node.stats}{Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node.} \item{seed}{Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed.} diff --git a/man/treeInfo.Rd b/man/treeInfo.Rd index 36033052..cdd12581 100644 --- a/man/treeInfo.Rd +++ b/man/treeInfo.Rd @@ -22,6 +22,8 @@ A data.frame with the columns \code{splitval} \tab The splitting value. For numeric or ordinal variables, all values smaller or equal go to the left, larger values to the right. For unordered factor variables see above. \cr \code{terminal} \tab Logical, TRUE for terminal nodes. \cr \code{prediction} \tab One column with the predicted class (factor) for classification and the predicted numerical value for regression. One probability per class for probability estimation in several columns. Nothing for survival, refer to \code{object$forest$chf} for the CHF node predictions. \cr + \code{numSamples} \tab Number of samples in the node (only if ranger called with \code{node.stats = TRUE}). 
\cr + \code{splitStat} \tab Split statistics, i.e., value of the splitting criterion (only if ranger called with \code{node.stats = TRUE}). \cr } } \description{ diff --git a/src/Forest.h b/src/Forest.h index 73d782dc..5b297202 100644 --- a/src/Forest.h +++ b/src/Forest.h @@ -159,6 +159,13 @@ class Forest { } return result; } + std::vector<std::vector<double>> getSplitStats() { + std::vector<std::vector<double>> result; + for (auto& tree : trees) { + result.push_back(tree->getSplitStats()); + } + return result; + } protected: void grow(); diff --git a/src/Tree.cpp b/src/Tree.cpp index 542e540a..57d3dfbd 100644 --- a/src/Tree.cpp +++ b/src/Tree.cpp @@ -390,6 +390,7 @@ void Tree::createEmptyNode() { if (save_node_stats) { num_samples_nodes.push_back(0); + split_stats.push_back(0); } createEmptyNodeInternal(); diff --git a/src/Tree.h b/src/Tree.h index 3536ce68..101c300d 100644 --- a/src/Tree.h +++ b/src/Tree.h @@ -82,6 +82,9 @@ class Tree { const std::vector<double>& getNodePredictions() const { return node_predictions; } + const std::vector<double>& getSplitStats() const { + return split_stats; + } protected: void createPossibleSplitVarSubset(std::vector<size_t>& result); @@ -203,6 +206,7 @@ class Tree { bool save_node_stats; std::vector<size_t> num_samples_nodes; std::vector<double> node_predictions; + std::vector<double> split_stats; // Holdout mode bool holdout; diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index 23bd8acf..bbb3b581 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -205,6 +205,11 @@ bool TreeClassification::findBestSplit(size_t nodeID, std::vector<size_t>& possi // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute gini index for this node and to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -564,6 +569,11 @@ bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vector& 
possible // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -568,6 +573,11 @@ bool TreeProbability::findBestSplitExtraTrees(size_t nodeID, std::vector // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { diff --git a/src/TreeRegression.cpp b/src/TreeRegression.cpp index abb4120d..c272695b 100644 --- a/src/TreeRegression.cpp +++ b/src/TreeRegression.cpp @@ -184,6 +184,11 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector& possible_ // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -513,6 +518,11 @@ bool TreeRegression::findBestSplitMaxstat(size_t nodeID, std::vector& po // If not terminal node save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_maxstat; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -561,6 +571,11 @@ bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector& // Save best values split_varIDs[nodeID] = 
best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -794,6 +809,11 @@ bool TreeRegression::findBestSplitBeta(size_t nodeID, std::vector& possi // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { diff --git a/src/TreeSurvival.cpp b/src/TreeSurvival.cpp index 678522b9..1c60ba8b 100644 --- a/src/TreeSurvival.cpp +++ b/src/TreeSurvival.cpp @@ -177,6 +177,11 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector& possible_sp // If not terminal node save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -308,6 +313,11 @@ bool TreeSurvival::findBestSplitMaxstat(size_t nodeID, std::vector& poss // If not terminal node save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_maxstat; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { @@ -734,6 +744,11 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector& p // If not terminal node save best values split_varIDs[nodeID] = best_varID; 
split_values[nodeID] = best_value; + + // Save split statistics + if (save_node_stats) { + split_stats[nodeID] = best_decrease; + } // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) { diff --git a/src/rangerCpp.cpp b/src/rangerCpp.cpp index 757381b6..c8c4fed2 100644 --- a/src/rangerCpp.cpp +++ b/src/rangerCpp.cpp @@ -260,6 +260,7 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM if (node_stats) { forest_object.push_back(forest->getNumSamplesNodes(), "num.samples.nodes"); + forest_object.push_back(forest->getSplitStats(), "split.stats"); } if (snp_data.nrow() > 1 && order_snps) { diff --git a/tests/testthat/test_nodestats.R b/tests/testthat/test_nodestats.R index 3470b927..538563bb 100644 --- a/tests/testthat/test_nodestats.R +++ b/tests/testthat/test_nodestats.R @@ -8,12 +8,14 @@ test_that("if node.stats FALSE, no nodestats saved, classification", { rf <- ranger(Species ~ ., iris, num.trees = 5) expect_null(rf$forest$num.samples.nodes) expect_null(rf$forest$node.predictions) + expect_null(rf$forest$split.stats) }) test_that("if node.stats FALSE, no nodestats saved, probability", { rf <- ranger(Species ~ ., iris, num.trees = 5, probability = TRUE) expect_null(rf$forest$num.samples.nodes) expect_null(rf$forest$node.predictions) + expect_null(rf$forest$split.stats) expect_length(rf$forest$terminal.class.counts[[1]][[1]], 0) }) @@ -21,12 +23,14 @@ test_that("if node.stats FALSE, no nodestats saved, regression", { rf <- ranger(Sepal.Length ~ ., iris, num.trees = 5) expect_null(rf$forest$num.samples.nodes) expect_null(rf$forest$node.predictions) + expect_null(rf$forest$split.stats) }) test_that("if node.stats FALSE, no nodestats saved, survival", { rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5) expect_null(rf$forest$num.samples.nodes) expect_null(rf$forest$node.predictions) + expect_null(rf$forest$split.stats) 
expect_length(rf$forest$chf[[1]][[1]], 0) }) @@ -40,6 +44,10 @@ test_that("if node.stats TRUE, nodestats saved, classification", { expect_is(rf$forest$node.predictions, "list") expect_length(rf$forest$node.predictions, rf$num.trees) expect_is(rf$forest$node.predictions[[1]], "numeric") + + expect_is(rf$forest$split.stats, "list") + expect_length(rf$forest$split.stats, rf$num.trees) + expect_is(rf$forest$split.stats[[1]], "numeric") }) test_that("if node.stats TRUE, nodestats saved, probability", { @@ -54,6 +62,10 @@ test_that("if node.stats TRUE, nodestats saved, probability", { expect_is(rf$forest$terminal.class.counts, "list") expect_length(rf$forest$terminal.class.counts, rf$num.trees) expect_length(rf$forest$terminal.class.counts[[1]][[1]], nlevels(iris$Species)) + + expect_is(rf$forest$split.stats, "list") + expect_length(rf$forest$split.stats, rf$num.trees) + expect_is(rf$forest$split.stats[[1]], "numeric") }) test_that("if node.stats TRUE, nodestats saved, regression", { @@ -66,6 +78,10 @@ test_that("if node.stats TRUE, nodestats saved, regression", { expect_is(rf$forest$node.predictions, "list") expect_length(rf$forest$node.predictions, rf$num.trees) expect_is(rf$forest$node.predictions[[1]], "numeric") + + expect_is(rf$forest$split.stats, "list") + expect_length(rf$forest$split.stats, rf$num.trees) + expect_is(rf$forest$split.stats[[1]], "numeric") }) test_that("if node.stats TRUE, nodestats saved, survival", { @@ -82,26 +98,11 @@ test_that("if node.stats TRUE, nodestats saved, survival", { expect_is(rf$forest$chf[[1]], "list") expect_is(rf$forest$chf[[1]][[1]], "numeric") expect_length(rf$forest$chf[[1]][[1]], length(rf$unique.death.times)) + + expect_is(rf$forest$split.stats, "list") + expect_length(rf$forest$split.stats, rf$num.trees) + expect_is(rf$forest$split.stats[[1]], "numeric") }) -rf <- ranger(Species ~ ., iris, num.trees = 10, probability = TRUE, node.stats = TRUE) -rf$forest$num.samples.nodes -rf$forest$node.predictions 
-rf$forest$terminal.class.counts - - -rf <- ranger(Sepal.Length ~ ., iris, num.trees = 10, node.stats = TRUE) -rf$forest$num.samples.nodes -rf$forest$node.predictions - -# Survival - -rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 10, node.stats = TRUE) -rf$forest$num.samples.nodes -rf$forest$node.predictions -rf$forest$chf - - -