Skip to content

Commit

Permalink
add split statistics per node
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Nov 7, 2023
1 parent d460960 commit 8faf91e
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 21 deletions.
2 changes: 1 addition & 1 deletion R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.
##' @param verbose Show computation status and estimated runtime.
##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction and number of observations for each node.
##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node.
##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed.
##' @param dependent.variable.name Name of dependent variable, needed if no formula given. For survival forests this is the time variable.
##' @param status.variable.name Name of status variable, only applicable to survival data and needed if no formula given. Use 1 for event and 0 for censoring.
Expand Down
6 changes: 6 additions & 0 deletions R/treeInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#' \code{splitval} \tab The splitting value. For numeric or ordinal variables, all values smaller or equal go to the left, larger values to the right. For unordered factor variables see above. \cr
#' \code{terminal} \tab Logical, TRUE for terminal nodes. \cr
#' \code{prediction} \tab One column with the predicted class (factor) for classification and the predicted numerical value for regression. One probability per class for probability estimation in several columns. Nothing for survival, refer to \code{object$forest$chf} for the CHF node predictions. \cr
#' \code{numSamples} \tab Number of samples in the node (only if ranger called with \code{node.stats = TRUE}). \cr
#' \code{splitStat} \tab Split statistics, i.e., value of the splitting criterion (only if ranger called with \code{node.stats = TRUE}). \cr
#' }
#' @examples
#' rf <- ranger(Species ~ ., data = iris)
Expand Down Expand Up @@ -164,6 +166,10 @@ treeInfo <- function(object, tree = 1) {
if (!is.null(forest$num.samples.nodes)) {
result$numSamples <- forest$num.samples.nodes[[tree]]
}
if (!is.null(forest$split.stats)) {
result$splitStat <- forest$split.stats[[tree]]
result$splitStat[result$terminal] <- NA
}

result
}
2 changes: 1 addition & 1 deletion man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/treeInfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,13 @@ class Forest {
}
return result;
}
std::vector<std::vector<double>> getSplitStats() {
std::vector<std::vector<double>> result;
for (auto& tree : trees) {
result.push_back(tree->getSplitStats());
}
return result;
}

protected:
void grow();
Expand Down
1 change: 1 addition & 0 deletions src/Tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ void Tree::createEmptyNode() {

if (save_node_stats) {
num_samples_nodes.push_back(0);
split_stats.push_back(0);
}

createEmptyNodeInternal();
Expand Down
4 changes: 4 additions & 0 deletions src/Tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class Tree {
// Read-only accessor: returns a const reference to this tree's node_predictions vector (no copy is made).
const std::vector<double>& getNodePredictions() const {
return node_predictions;
}
// Read-only accessor: returns a const reference to this tree's split_stats vector (no copy is made).
// NOTE(review): entries appear to be indexed by nodeID and only filled in when save_node_stats is set — confirm against the tree-growing code.
const std::vector<double>& getSplitStats() const {
return split_stats;
}

protected:
void createPossibleSplitVarSubset(std::vector<size_t>& result);
Expand Down Expand Up @@ -203,6 +206,7 @@ class Tree {
bool save_node_stats;
std::vector<size_t> num_samples_nodes;
std::vector<double> node_predictions;
std::vector<double> split_stats;

// Holdout mode
bool holdout;
Expand Down
10 changes: 10 additions & 0 deletions src/TreeClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ bool TreeClassification::findBestSplit(size_t nodeID, std::vector<size_t>& possi
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute gini index for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -564,6 +569,11 @@ bool TreeClassification::findBestSplitExtraTrees(size_t nodeID, std::vector<size
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute gini index for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down
10 changes: 10 additions & 0 deletions src/TreeProbability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ bool TreeProbability::findBestSplit(size_t nodeID, std::vector<size_t>& possible
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -568,6 +573,11 @@ bool TreeProbability::findBestSplitExtraTrees(size_t nodeID, std::vector<size_t>
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down
20 changes: 20 additions & 0 deletions src/TreeRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ bool TreeRegression::findBestSplit(size_t nodeID, std::vector<size_t>& possible_
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -513,6 +518,11 @@ bool TreeRegression::findBestSplitMaxstat(size_t nodeID, std::vector<size_t>& po
// If not terminal node save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_maxstat;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -561,6 +571,11 @@ bool TreeRegression::findBestSplitExtraTrees(size_t nodeID, std::vector<size_t>&
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -794,6 +809,11 @@ bool TreeRegression::findBestSplitBeta(size_t nodeID, std::vector<size_t>& possi
// Save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down
15 changes: 15 additions & 0 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ bool TreeSurvival::findBestSplit(size_t nodeID, std::vector<size_t>& possible_sp
// If not terminal node save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -308,6 +313,11 @@ bool TreeSurvival::findBestSplitMaxstat(size_t nodeID, std::vector<size_t>& poss
// If not terminal node save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_maxstat;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down Expand Up @@ -734,6 +744,11 @@ bool TreeSurvival::findBestSplitExtraTrees(size_t nodeID, std::vector<size_t>& p
// If not terminal node save best values
split_varIDs[nodeID] = best_varID;
split_values[nodeID] = best_value;

// Save split statistics
if (save_node_stats) {
split_stats[nodeID] = best_decrease;
}

// Compute decrease of impurity for this node and add to variable importance if needed
if (importance_mode == IMP_GINI || importance_mode == IMP_GINI_CORRECTED) {
Expand Down
1 change: 1 addition & 0 deletions src/rangerCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ Rcpp::List rangerCpp(uint treetype, Rcpp::NumericMatrix& input_x, Rcpp::NumericM

if (node_stats) {
forest_object.push_back(forest->getNumSamplesNodes(), "num.samples.nodes");
forest_object.push_back(forest->getSplitStats(), "split.stats");
}

if (snp_data.nrow() > 1 && order_snps) {
Expand Down
39 changes: 20 additions & 19 deletions tests/testthat/test_nodestats.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,29 @@ test_that("if node.stats FALSE, no nodestats saved, classification", {
rf <- ranger(Species ~ ., iris, num.trees = 5)
expect_null(rf$forest$num.samples.nodes)
expect_null(rf$forest$node.predictions)
expect_null(rf$forest$split.stats)
})

test_that("if node.stats FALSE, no nodestats saved, probability", {
  # Small probability forest grown without passing node.stats
  forest <- ranger(Species ~ ., iris, num.trees = 5, probability = TRUE)$forest

  # None of the per-node statistics may be stored
  expect_null(forest$num.samples.nodes)
  expect_null(forest$node.predictions)
  expect_null(forest$split.stats)

  # First class-count entry of the first tree must have length 0
  first_counts <- forest$terminal.class.counts[[1]][[1]]
  expect_length(first_counts, 0)
})

test_that("if node.stats FALSE, no nodestats saved, regression", {
  # Small regression forest grown without passing node.stats
  forest <- ranger(Sepal.Length ~ ., iris, num.trees = 5)$forest

  # None of the per-node statistics may be stored
  expect_null(forest$num.samples.nodes)
  expect_null(forest$node.predictions)
  expect_null(forest$split.stats)
})

test_that("if node.stats FALSE, no nodestats saved, survival", {
  # Small survival forest grown without passing node.stats
  forest <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5)$forest

  # None of the per-node statistics may be stored
  expect_null(forest$num.samples.nodes)
  expect_null(forest$node.predictions)
  expect_null(forest$split.stats)

  # First CHF entry of the first tree must have length 0
  first_chf <- forest$chf[[1]][[1]]
  expect_length(first_chf, 0)
})

Expand All @@ -40,6 +44,10 @@ test_that("if node.stats TRUE, nodestats saved, classification", {
expect_is(rf$forest$node.predictions, "list")
expect_length(rf$forest$node.predictions, rf$num.trees)
expect_is(rf$forest$node.predictions[[1]], "numeric")

expect_is(rf$forest$split.stats, "list")
expect_length(rf$forest$split.stats, rf$num.trees)
expect_is(rf$forest$split.stats[[1]], "numeric")
})

test_that("if node.stats TRUE, nodestats saved, probability", {
Expand All @@ -54,6 +62,10 @@ test_that("if node.stats TRUE, nodestats saved, probability", {
expect_is(rf$forest$terminal.class.counts, "list")
expect_length(rf$forest$terminal.class.counts, rf$num.trees)
expect_length(rf$forest$terminal.class.counts[[1]][[1]], nlevels(iris$Species))

expect_is(rf$forest$split.stats, "list")
expect_length(rf$forest$split.stats, rf$num.trees)
expect_is(rf$forest$split.stats[[1]], "numeric")
})

test_that("if node.stats TRUE, nodestats saved, regression", {
Expand All @@ -66,6 +78,10 @@ test_that("if node.stats TRUE, nodestats saved, regression", {
expect_is(rf$forest$node.predictions, "list")
expect_length(rf$forest$node.predictions, rf$num.trees)
expect_is(rf$forest$node.predictions[[1]], "numeric")

expect_is(rf$forest$split.stats, "list")
expect_length(rf$forest$split.stats, rf$num.trees)
expect_is(rf$forest$split.stats[[1]], "numeric")
})

test_that("if node.stats TRUE, nodestats saved, survival", {
Expand All @@ -82,26 +98,11 @@ test_that("if node.stats TRUE, nodestats saved, survival", {
expect_is(rf$forest$chf[[1]], "list")
expect_is(rf$forest$chf[[1]][[1]], "numeric")
expect_length(rf$forest$chf[[1]][[1]], length(rf$unique.death.times))

expect_is(rf$forest$split.stats, "list")
expect_length(rf$forest$split.stats, rf$num.trees)
expect_is(rf$forest$split.stats[[1]], "numeric")
})



rf <- ranger(Species ~ ., iris, num.trees = 10, probability = TRUE, node.stats = TRUE)
rf$forest$num.samples.nodes
rf$forest$node.predictions
rf$forest$terminal.class.counts


rf <- ranger(Sepal.Length ~ ., iris, num.trees = 10, node.stats = TRUE)
rf$forest$num.samples.nodes
rf$forest$node.predictions

# Survival

rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 10, node.stats = TRUE)
rf$forest$num.samples.nodes
rf$forest$node.predictions
rf$forest$chf



0 comments on commit 8faf91e

Please sign in to comment.