Skip to content

Commit

Permalink
Merge pull request #721 from imbs-hl/class_specific_nodesize
Browse files Browse the repository at this point in the history
Allow vector min.node.size/min.bucket for class-wise limits
  • Loading branch information
mnwright authored May 16, 2024
2 parents 858bfda + d5b3e66 commit 8e58766
Show file tree
Hide file tree
Showing 19 changed files with 519 additions and 146 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.0
Date: 2023-11-09
Version: 0.16.1
Date: 2024-05-15
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand All @@ -19,7 +19,7 @@ Suggests:
survival,
testthat
Encoding: UTF-8
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
URL: http://imbs-hl.github.io/ranger/,
https://github.com/imbs-hl/ranger
BugReports: https://github.com/imbs-hl/ranger/issues
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.16.1
* Allow vector min.node.size and min.bucket for class-specific limits

# ranger 0.16.0
* New CRAN version

Expand Down
53 changes: 46 additions & 7 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@
##' @param importance Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification, the variance of the responses for regression and the sum of test statistics (see \code{splitrule}) for survival.
##' @param write.forest Save \code{ranger.forest} object, required for prediction. Set to \code{FALSE} to reduce memory usage if no prediction is intended.
##' @param probability Grow a probability forest as in Malley et al. (2012).
##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability.
##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types.
##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. For classification, this can be a vector of class-specific values.
##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values.
##' @param max.depth Maximal tree depth. A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree).
##' @param replace Sample with replacement.
##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values.
Expand Down Expand Up @@ -359,6 +359,15 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Unsupported type of dependent variable.")
}

## Number of levels
if (treetype %in% c(1, 9)) {
if (is.factor(y)) {
num_levels <- nlevels(y)
} else {
num_levels <- length(unique(y))
}
}

## Quantile prediction only for regression
if (quantreg && treetype != 3) {
stop("Error: Quantile prediction implemented only for regression outcomes.")
Expand Down Expand Up @@ -522,16 +531,46 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
## Minimum node size
if (is.null(min.node.size)) {
min.node.size <- 0
} else if (!is.numeric(min.node.size) || min.node.size < 0) {
stop("Error: Invalid value for min.node.size")
} else if (!is.numeric(min.node.size)) {
stop("Error: Invalid value for min.node.size.")
}
if (length(min.node.size) > 1) {
if (!(treetype %in% c(1, 9))) {
stop("Error: Invalid value for min.node.size. Vector values only valid for classification forests.")
}
if (any(min.node.size < 0)) {
stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.")
}
if (length(min.node.size) != num_levels) {
stop("Error: Invalid value for min.node.size Expecting ", num_levels, " values, provided ", length(min.node.size), ".")
}
} else {
if (min.node.size < 0) {
stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.")
}
}

## Minimum bucket size
if (is.null(min.bucket)) {
min.bucket <- 0
} else if (!is.numeric(min.bucket) || min.bucket < 0) {
} else if (!is.numeric(min.bucket)) {
stop("Error: Invalid value for min.bucket")
}
if (length(min.bucket) > 1) {
if (!(treetype %in% c(1, 9))) {
stop("Error: Invalid value for min.bucket Vector values only valid for classification forests.")
}
if (any(min.bucket < 0)) {
stop("Error: Invalid value for min.bucket Please give a nonnegative value or a vector of nonnegative values.")
}
if (length(min.bucket) != num_levels) {
stop("Error: Invalid value for min.bucket Expecting ", num_levels, " values, provided ", length(min.bucket), ".")
}
} else {
if (min.bucket < 0) {
stop("Error: Invalid value for min.bucket Please give a nonnegative value or a vector of nonnegative values.")
}
}

## Tree depth
if (is.null(max.depth)) {
Expand All @@ -554,8 +593,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
if (sum(sample.fraction) <= 0) {
stop("Error: Invalid value for sample.fraction. Sum of values must be >0.")
}
if (length(sample.fraction) != nlevels(y)) {
stop("Error: Invalid value for sample.fraction. Expecting ", nlevels(y), " values, provided ", length(sample.fraction), ".")
if (length(sample.fraction) != num_levels) {
stop("Error: Invalid value for sample.fraction. Expecting ", num_levels, " values, provided ", length(sample.fraction), ".")
}
if (!replace & any(sample.fraction * length(y) > table(y))) {
idx <- which(sample.fraction * length(y) > table(y))[1]
Expand Down
4 changes: 2 additions & 2 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
namespace ranger {

Forest::Forest() :
verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size(0), min_bucket(0), num_independent_variables(0), seed(0), num_samples(
verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size({0}), min_bucket({0}), num_independent_variables(0), seed(0), num_samples(
0), prediction_mode(false), memory_mode(MEM_DOUBLE), sample_with_replacement(true), memory_saving_splitting(
false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction( { 1 }), holdout(
false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), max_depth(
Expand Down Expand Up @@ -62,6 +62,9 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
if (!load_forest_filename.empty()) {
prediction_mode = true;
}

std::vector<uint> min_node_size_vector = { min_node_size };
std::vector<uint> min_bucket_vector = { min_bucket };

// Sample fraction default and convert to vector
if (sample_fraction == 0) {
Expand All @@ -79,7 +82,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode

// Call other init function
init(loadDataFromFile(input_file), mtry, output_prefix, num_trees, seed, num_threads, importance_mode,
min_node_size, min_bucket, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting,
min_node_size_vector, min_bucket_vector, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting,
splitrule, predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits,
false, max_depth, regularization_factor, regularization_usedepth, false);

Expand Down Expand Up @@ -133,7 +136,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
// #nocov end

void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed,
uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, std::vector<double>& case_weights,
Expand Down Expand Up @@ -178,7 +181,7 @@ void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees,
}

void Forest::init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, bool order_snps,
Expand Down Expand Up @@ -323,7 +326,7 @@ void Forest::writeOutput() {
*verbose_out << "Sample size: " << num_samples << std::endl;
*verbose_out << "Number of independent variables: " << num_independent_variables << std::endl;
*verbose_out << "Mtry: " << mtry << std::endl;
*verbose_out << "Target node size: " << min_node_size << std::endl;
*verbose_out << "Target node size: " << min_node_size[0] << std::endl;
*verbose_out << "Variable importance mode: " << importance_mode << std::endl;
*verbose_out << "Memory mode: " << memory_mode << std::endl;
*verbose_out << "Seed: " << seed << std::endl;
Expand Down Expand Up @@ -473,7 +476,7 @@ void Forest::grow() {
}

trees[i]->init(data.get(), mtry, num_samples, tree_seed, &deterministic_varIDs, tree_split_select_weights,
importance_mode, min_node_size, min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights,
importance_mode, &min_node_size, &min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights,
tree_manual_inbag, keep_inbag, &sample_fraction, alpha, minprop, holdout, num_random_splits, max_depth,
&regularization_factor, regularization_usedepth, &split_varIDs_used, save_node_stats);
}
Expand Down
12 changes: 6 additions & 6 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Forest {
bool holdout, PredictionType prediction_type, uint num_random_splits, uint max_depth,
const std::vector<double>& regularization_factor, bool regularization_usedepth);
void initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed,
uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
std::vector<std::vector<double>>& split_select_weights,
const std::vector<std::string>& always_split_variable_names, bool prediction_mode, bool sample_with_replacement,
const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
Expand All @@ -58,7 +58,7 @@ class Forest {
const std::vector<double>& regularization_factor, bool regularization_usedepth,
bool node_stats);
void init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits,
Expand Down Expand Up @@ -119,10 +119,10 @@ class Forest {
// Read-only accessor: returns the value currently held in the mtry member.
uint getMtry() const {
  return this->mtry;
}
uint getMinNodeSize() const {
const std::vector<uint>& getMinNodeSize() const {
return min_node_size;
}
uint getMinBucket() const {
const std::vector<uint>& getMinBucket() const {
return min_bucket;
}
size_t getNumIndependentVariables() const {
Expand Down Expand Up @@ -209,8 +209,8 @@ class Forest {
std::vector<std::string> dependent_variable_names; // time,status for survival
size_t num_trees;
uint mtry;
uint min_node_size;
uint min_bucket;
std::vector<uint> min_node_size;
std::vector<uint> min_bucket;
size_t num_independent_variables;
uint seed;
size_t num_samples;
Expand Down
8 changes: 4 additions & 4 deletions src/ForestClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ void ForestClassification::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Create class_values and response_classIDs
Expand Down
8 changes: 4 additions & 4 deletions src/ForestProbability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ void ForestProbability::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_PROBABILITY;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Create class_values and response_classIDs
Expand Down
8 changes: 4 additions & 4 deletions src/ForestRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ void ForestRegression::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_REGRESSION;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_REGRESSION;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Error if beta splitrule used with data outside of [0,1]
Expand Down
8 changes: 4 additions & 4 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ void ForestSurvival::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_SURVIVAL;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_SURVIVAL;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET_SURVIVAL;
}

// Sort data if extratrees and not memory saving mode
Expand Down
Loading

0 comments on commit 8e58766

Please sign in to comment.