Skip to content

Commit

Permalink
Merge pull request #690 from imbs-hl/split_stats
Browse files Browse the repository at this point in the history
Add split statistics option
  • Loading branch information
mnwright authored Nov 8, 2023
2 parents 5550eaf + 53e6455 commit 9722742
Show file tree
Hide file tree
Showing 23 changed files with 372 additions and 53 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.15.4
Date: 2023-11-03
Date: 2023-11-07
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

# ranger 0.15.4
* Add node.stats option to save node statistics of all nodes
* Add time.interest option to restrict unique survival times (faster and saves memory)

# ranger 0.15.3
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, time_interest, use_time_interest)
rangerCpp <- function(treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest) {
.Call(`_ranger_rangerCpp`, treetype, input_x, input_y, variable_names, mtry, num_trees, verbose, seed, num_threads, write_forest, importance_mode_r, min_node_size, min_bucket, split_select_weights, use_split_select_weights, always_split_variable_names, use_always_split_variable_names, prediction_mode, loaded_forest, snp_data, sample_with_replacement, probability, unordered_variable_names, use_unordered_variable_names, save_memory, splitrule_r, case_weights, use_case_weights, class_weights, predict_all, keep_inbag, sample_fraction, alpha, minprop, holdout, prediction_type_r, num_random_splits, sparse_x, use_sparse_data, order_snps, oob_error, max_depth, inbag, use_inbag, regularization_factor, use_regularization_factor, regularization_usedepth, node_stats, time_interest, use_time_interest)
}

numSmaller <- function(values, reference) {
Expand Down
3 changes: 2 additions & 1 deletion R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
regularization.factor <- c(0, 0)
use.regularization.factor <- FALSE
regularization.usedepth <- FALSE
node.stats <- FALSE
time.interest <- c(0, 0)
use.time.interest <- FALSE

Expand All @@ -276,7 +277,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
prediction.type, num.random.splits, sparse.x, use.sparse.data,
order.snps, oob.error, max.depth, inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)
node.stats, time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
5 changes: 3 additions & 2 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
##' @param num.threads Number of threads. Default is number of CPUs available.
##' @param save.memory Use memory saving (but slower) splitting mode. No effect for survival and GWAS data. Warning: This option slows down the tree growing, use only if you encounter memory problems.
##' @param verbose Show computation status and estimated runtime.
##' @param node.stats Save node statistics. Set to \code{TRUE} to save prediction, number of observations and split statistics for each node.
##' @param seed Random seed. Default is \code{NULL}, which generates the seed from \code{R}. Set to \code{0} to ignore the \code{R} seed.
##' @param dependent.variable.name Name of dependent variable, needed if no formula given. For survival forests this is the time variable.
##' @param status.variable.name Name of status variable, only applicable to survival data and needed if no formula given. Use 1 for event and 0 for censoring.
Expand Down Expand Up @@ -244,7 +245,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
keep.inbag = FALSE, inbag = NULL, holdout = FALSE,
quantreg = FALSE, time.interest = NULL, oob.error = TRUE,
num.threads = NULL, save.memory = FALSE,
verbose = TRUE, seed = NULL,
verbose = TRUE, node.stats = FALSE, seed = NULL,
dependent.variable.name = NULL, status.variable.name = NULL,
classification = NULL, x = NULL, y = NULL, ...) {

Expand Down Expand Up @@ -924,7 +925,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
num.random.splits, sparse.x, use.sparse.data, order.snps, oob.error, max.depth,
inbag, use.inbag,
regularization.factor, use.regularization.factor, regularization.usedepth,
time.interest, use.time.interest)
node.stats, time.interest, use.time.interest)

if (length(result) == 0) {
stop("User interrupt or internal error.")
Expand Down
39 changes: 34 additions & 5 deletions R/treeInfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
#' \code{splitval} \tab The splitting value. For numeric or ordinal variables, all values smaller or equal go to the left, larger values to the right. For unordered factor variables see above. \cr
#' \code{terminal} \tab Logical, TRUE for terminal nodes. \cr
#' \code{prediction} \tab One column with the predicted class (factor) for classification and the predicted numerical value for regression. One probability per class for probability estimation in several columns. Nothing for survival, refer to \code{object$forest$chf} for the CHF node predictions. \cr
#' \code{numSamples} \tab Number of samples in the node (only if ranger called with \code{node.stats = TRUE}). \cr
#' \code{splitStat} \tab Split statistics, i.e., value of the splitting criterion (only if ranger called with \code{node.stats = TRUE}). \cr
#' }
#' @examples
#' rf <- ranger(Species ~ ., data = iris)
Expand Down Expand Up @@ -117,17 +119,35 @@ treeInfo <- function(object, tree = 1) {

## Prediction
if (forest$treetype == "Classification") {
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
if (is.null(forest$num.samples.nodes)) {
# split.stats=FALSE
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
} else {
# split.stats=TRUE
result$prediction <- forest$node.predictions[[tree]]
}
if (!is.null(forest$levels)) {
result$prediction <- integer.to.factor(result$prediction, labels = forest$levels)
}
} else if (forest$treetype == "Regression") {
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
if (is.null(forest$num.samples.nodes)) {
# split.stats=FALSE
result$prediction <- forest$split.values[[tree]]
result$prediction[!result$terminal] <- NA
} else {
# split.stats=TRUE
result$prediction <- forest$node.predictions[[tree]]
}
} else if (forest$treetype == "Probability estimation") {
predictions <- matrix(nrow = nrow(result), ncol = length(forest$class.values))
predictions[result$terminal, ] <- do.call(rbind, forest$terminal.class.counts[[tree]])
if (is.null(forest$num.samples.nodes)) {
# split.stats=FALSE
predictions[result$terminal, ] <- do.call(rbind, forest$terminal.class.counts[[tree]])
} else {
# split.stats=TRUE
predictions <- do.call(rbind, forest$terminal.class.counts[[tree]])
}
if (!is.null(forest$levels)) {
colnames(predictions) <- forest$levels[forest$class.values]
predictions <- predictions[, forest$levels[sort(forest$class.values)], drop = FALSE]
Expand All @@ -142,5 +162,14 @@ treeInfo <- function(object, tree = 1) {
stop("Error: Unknown tree type.")
}

## Node statistics
if (!is.null(forest$num.samples.nodes)) {
result$numSamples <- forest$num.samples.nodes[[tree]]
}
if (!is.null(forest$split.stats)) {
result$splitStat <- forest$split.stats[[tree]]
result$splitStat[result$terminal] <- NA
}

result
}
2 changes: 1 addition & 1 deletion cpp_version/src/version.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#ifndef RANGER_VERSION
#define RANGER_VERSION "0.15.3"
#define RANGER_VERSION "0.15.4"
#endif
3 changes: 3 additions & 0 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/treeInfo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
init(loadDataFromFile(input_file), mtry, output_prefix, num_trees, seed, num_threads, importance_mode,
min_node_size, min_bucket, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting,
splitrule, predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits,
false, max_depth, regularization_factor, regularization_usedepth);
false, max_depth, regularization_factor, regularization_usedepth, false);

if (prediction_mode) {
loadFromFile(load_forest_filename);
Expand Down Expand Up @@ -140,15 +140,15 @@ void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees,
std::vector<std::vector<size_t>>& manual_inbag, bool predict_all, bool keep_inbag,
std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout, PredictionType prediction_type,
uint num_random_splits, bool order_snps, uint max_depth, const std::vector<double>& regularization_factor,
bool regularization_usedepth) {
bool regularization_usedepth, bool node_stats) {

this->verbose_out = verbose_out;

// Call other init function
init(std::move(input_data), mtry, "", num_trees, seed, num_threads, importance_mode, min_node_size, min_bucket,
prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting, splitrule,
predict_all, sample_fraction, alpha, minprop, holdout, prediction_type, num_random_splits, order_snps, max_depth,
regularization_factor, regularization_usedepth);
regularization_factor, regularization_usedepth, node_stats);

// Set variables to be always considered for splitting
if (!always_split_variable_names.empty()) {
Expand Down Expand Up @@ -182,7 +182,7 @@ void Forest::init(std::unique_ptr<Data> input_data, uint mtry, std::string outpu
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, bool order_snps,
uint max_depth, const std::vector<double>& regularization_factor, bool regularization_usedepth) {
uint max_depth, const std::vector<double>& regularization_factor, bool regularization_usedepth, bool node_stats) {

// Initialize data with memmode
this->data = std::move(input_data);
Expand Down Expand Up @@ -224,6 +224,7 @@ void Forest::init(std::unique_ptr<Data> input_data, uint mtry, std::string outpu
this->max_depth = max_depth;
this->regularization_factor = regularization_factor;
this->regularization_usedepth = regularization_usedepth;
this->save_node_stats = node_stats;

// Set number of samples and variables
num_samples = data->getNumRows();
Expand Down Expand Up @@ -474,7 +475,7 @@ void Forest::grow() {
trees[i]->init(data.get(), mtry, num_samples, tree_seed, &deterministic_varIDs, tree_split_select_weights,
importance_mode, min_node_size, min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights,
tree_manual_inbag, keep_inbag, &sample_fraction, alpha, minprop, holdout, num_random_splits, max_depth,
&regularization_factor, regularization_usedepth, &split_varIDs_used);
&regularization_factor, regularization_usedepth, &split_varIDs_used, save_node_stats);
}

// Init variable importance
Expand Down
29 changes: 27 additions & 2 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@ class Forest {
std::vector<double>& case_weights, std::vector<std::vector<size_t>>& manual_inbag, bool predict_all,
bool keep_inbag, std::vector<double>& sample_fraction, double alpha, double minprop, bool holdout,
PredictionType prediction_type, uint num_random_splits, bool order_snps, uint max_depth,
const std::vector<double>& regularization_factor, bool regularization_usedepth);
const std::vector<double>& regularization_factor, bool regularization_usedepth,
bool node_stats);
void init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits,
bool order_snps, uint max_depth, const std::vector<double>& regularization_factor, bool regularization_usedepth);
bool order_snps, uint max_depth, const std::vector<double>& regularization_factor, bool regularization_usedepth,
bool node_stats);
virtual void initInternal() = 0;

// Grow or predict
Expand Down Expand Up @@ -142,6 +144,28 @@ class Forest {
const std::vector<std::vector<size_t>>& getSnpOrder() const {
return data->getSnpOrder();
}

std::vector<std::vector<size_t>> getNumSamplesNodes() {
std::vector<std::vector<size_t>> result;
for (auto& tree : trees) {
result.push_back(tree->getNumSamplesNodes());
}
return result;
}
std::vector<std::vector<double>> getNodePredictions() {
std::vector<std::vector<double>> result;
for (auto& tree : trees) {
result.push_back(tree->getNodePredictions());
}
return result;
}
std::vector<std::vector<double>> getSplitStats() {
std::vector<std::vector<double>> result;
for (auto& tree : trees) {
result.push_back(tree->getSplitStats());
}
return result;
}

protected:
void grow();
Expand Down Expand Up @@ -202,6 +226,7 @@ class Forest {
PredictionType prediction_type;
uint num_random_splits;
uint max_depth;
bool save_node_stats;

// MAXSTAT splitrule
double alpha;
Expand Down
Loading

0 comments on commit 9722742

Please sign in to comment.