Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow vector min.node.size/min.bucket for class-wise limits #721

Merged
2 commits merged on May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.16.0
Date: 2023-11-09
Version: 0.16.1
Date: 2024-05-15
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high
Expand All @@ -19,7 +19,7 @@ Suggests:
survival,
testthat
Encoding: UTF-8
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
URL: http://imbs-hl.github.io/ranger/,
https://github.com/imbs-hl/ranger
BugReports: https://github.com/imbs-hl/ranger/issues
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@

# ranger 0.16.1
* Allow vector min.node.size and min.bucket for class-specific limits

# ranger 0.16.0
* New CRAN version

Expand Down
53 changes: 46 additions & 7 deletions R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@
##' @param importance Variable importance mode, one of 'none', 'impurity', 'impurity_corrected', 'permutation'. The 'impurity' measure is the Gini index for classification, the variance of the responses for regression and the sum of test statistics (see \code{splitrule}) for survival.
##' @param write.forest Save \code{ranger.forest} object, required for prediction. Set to \code{FALSE} to reduce memory usage if no prediction intended.
##' @param probability Grow a probability forest as in Malley et al. (2012).
##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability.
##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types.
##' @param min.node.size Minimal node size to split at. Default 1 for classification, 5 for regression, 3 for survival, and 10 for probability. For classification, this can be a vector of class-specific values.
##' @param min.bucket Minimal terminal node size. No nodes smaller than this value can occur. Default 3 for survival and 1 for all other tree types. For classification, this can be a vector of class-specific values.
##' @param max.depth Maximal tree depth. A value of NULL or 0 (the default) corresponds to unlimited depth, 1 to tree stumps (1 split per tree).
##' @param replace Sample with replacement.
##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values.
Expand Down Expand Up @@ -359,6 +359,15 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
stop("Error: Unsupported type of dependent variable.")
}

## Number of levels
if (treetype %in% c(1, 9)) {
if (is.factor(y)) {
num_levels <- nlevels(y)
} else {
num_levels <- length(unique(y))
}
}

## Quantile prediction only for regression
if (quantreg && treetype != 3) {
stop("Error: Quantile prediction implemented only for regression outcomes.")
Expand Down Expand Up @@ -522,16 +531,46 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
## Minimum node size
if (is.null(min.node.size)) {
min.node.size <- 0
} else if (!is.numeric(min.node.size) || min.node.size < 0) {
stop("Error: Invalid value for min.node.size")
} else if (!is.numeric(min.node.size)) {
stop("Error: Invalid value for min.node.size.")
}
if (length(min.node.size) > 1) {
if (!(treetype %in% c(1, 9))) {
stop("Error: Invalid value for min.node.size. Vector values only valid for classification forests.")
}
if (any(min.node.size < 0)) {
stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.")
}
if (length(min.node.size) != num_levels) {
stop("Error: Invalid value for min.node.size Expecting ", num_levels, " values, provided ", length(min.node.size), ".")
}
} else {
if (min.node.size < 0) {
stop("Error: Invalid value for min.node.size. Please give a nonnegative value or a vector of nonnegative values.")
}
}

## Minimum bucket size
if (is.null(min.bucket)) {
min.bucket <- 0
} else if (!is.numeric(min.bucket) || min.bucket < 0) {
} else if (!is.numeric(min.bucket)) {
stop("Error: Invalid value for min.bucket")
}
if (length(min.bucket) > 1) {
if (!(treetype %in% c(1, 9))) {
stop("Error: Invalid value for min.bucket Vector values only valid for classification forests.")
}
if (any(min.bucket < 0)) {
stop("Error: Invalid value for min.bucket Please give a nonnegative value or a vector of nonnegative values.")
}
if (length(min.bucket) != num_levels) {
stop("Error: Invalid value for min.bucket Expecting ", num_levels, " values, provided ", length(min.bucket), ".")
}
} else {
if (min.bucket < 0) {
stop("Error: Invalid value for min.bucket Please give a nonnegative value or a vector of nonnegative values.")
}
}

## Tree depth
if (is.null(max.depth)) {
Expand All @@ -554,8 +593,8 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
if (sum(sample.fraction) <= 0) {
stop("Error: Invalid value for sample.fraction. Sum of values must be >0.")
}
if (length(sample.fraction) != nlevels(y)) {
stop("Error: Invalid value for sample.fraction. Expecting ", nlevels(y), " values, provided ", length(sample.fraction), ".")
if (length(sample.fraction) != num_levels) {
stop("Error: Invalid value for sample.fraction. Expecting ", num_levels, " values, provided ", length(sample.fraction), ".")
}
if (!replace & any(sample.fraction * length(y) > table(y))) {
idx <- which(sample.fraction * length(y) > table(y))[1]
Expand Down
4 changes: 2 additions & 2 deletions man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions src/Forest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
namespace ranger {

Forest::Forest() :
verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size(0), min_bucket(0), num_independent_variables(0), seed(0), num_samples(
verbose_out(0), num_trees(DEFAULT_NUM_TREE), mtry(0), min_node_size({0}), min_bucket({0}), num_independent_variables(0), seed(0), num_samples(
0), prediction_mode(false), memory_mode(MEM_DOUBLE), sample_with_replacement(true), memory_saving_splitting(
false), splitrule(DEFAULT_SPLITRULE), predict_all(false), keep_inbag(false), sample_fraction( { 1 }), holdout(
false), prediction_type(DEFAULT_PREDICTIONTYPE), num_random_splits(DEFAULT_NUM_RANDOM_SPLITS), max_depth(
Expand Down Expand Up @@ -62,6 +62,9 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
if (!load_forest_filename.empty()) {
prediction_mode = true;
}

std::vector<uint> min_node_size_vector = { min_node_size };
std::vector<uint> min_bucket_vector = { min_bucket };

// Sample fraction default and convert to vector
if (sample_fraction == 0) {
Expand All @@ -79,7 +82,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode

// Call other init function
init(loadDataFromFile(input_file), mtry, output_prefix, num_trees, seed, num_threads, importance_mode,
min_node_size, min_bucket, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting,
min_node_size_vector, min_bucket_vector, prediction_mode, sample_with_replacement, unordered_variable_names, memory_saving_splitting,
splitrule, predict_all, sample_fraction_vector, alpha, minprop, holdout, prediction_type, num_random_splits,
false, max_depth, regularization_factor, regularization_usedepth, false);

Expand Down Expand Up @@ -133,7 +136,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
// #nocov end

void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed,
uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
std::vector<std::vector<double>>& split_select_weights, const std::vector<std::string>& always_split_variable_names,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, std::vector<double>& case_weights,
Expand Down Expand Up @@ -178,7 +181,7 @@ void Forest::initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees,
}

void Forest::init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits, bool order_snps,
Expand Down Expand Up @@ -323,7 +326,7 @@ void Forest::writeOutput() {
*verbose_out << "Sample size: " << num_samples << std::endl;
*verbose_out << "Number of independent variables: " << num_independent_variables << std::endl;
*verbose_out << "Mtry: " << mtry << std::endl;
*verbose_out << "Target node size: " << min_node_size << std::endl;
*verbose_out << "Target node size: " << min_node_size[0] << std::endl;
*verbose_out << "Variable importance mode: " << importance_mode << std::endl;
*verbose_out << "Memory mode: " << memory_mode << std::endl;
*verbose_out << "Seed: " << seed << std::endl;
Expand Down Expand Up @@ -473,7 +476,7 @@ void Forest::grow() {
}

trees[i]->init(data.get(), mtry, num_samples, tree_seed, &deterministic_varIDs, tree_split_select_weights,
importance_mode, min_node_size, min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights,
importance_mode, &min_node_size, &min_bucket, sample_with_replacement, memory_saving_splitting, splitrule, &case_weights,
tree_manual_inbag, keep_inbag, &sample_fraction, alpha, minprop, holdout, num_random_splits, max_depth,
&regularization_factor, regularization_usedepth, &split_varIDs_used, save_node_stats);
}
Expand Down
12 changes: 6 additions & 6 deletions src/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Forest {
bool holdout, PredictionType prediction_type, uint num_random_splits, uint max_depth,
const std::vector<double>& regularization_factor, bool regularization_usedepth);
void initR(std::unique_ptr<Data> input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed,
uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
std::vector<std::vector<double>>& split_select_weights,
const std::vector<std::string>& always_split_variable_names, bool prediction_mode, bool sample_with_replacement,
const std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
Expand All @@ -58,7 +58,7 @@ class Forest {
const std::vector<double>& regularization_factor, bool regularization_usedepth,
bool node_stats);
void init(std::unique_ptr<Data> input_data, uint mtry, std::string output_prefix,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, uint min_bucket,
uint num_trees, uint seed, uint num_threads, ImportanceMode importance_mode, std::vector<uint>& min_node_size, std::vector<uint>& min_bucket,
bool prediction_mode, bool sample_with_replacement, const std::vector<std::string>& unordered_variable_names,
bool memory_saving_splitting, SplitRule splitrule, bool predict_all, std::vector<double>& sample_fraction,
double alpha, double minprop, bool holdout, PredictionType prediction_type, uint num_random_splits,
Expand Down Expand Up @@ -119,10 +119,10 @@ class Forest {
uint getMtry() const {
return mtry;
}
uint getMinNodeSize() const {
const std::vector<uint>& getMinNodeSize() const {
return min_node_size;
}
uint getMinBucket() const {
const std::vector<uint>& getMinBucket() const {
return min_bucket;
}
size_t getNumIndependentVariables() const {
Expand Down Expand Up @@ -209,8 +209,8 @@ class Forest {
std::vector<std::string> dependent_variable_names; // time,status for survival
size_t num_trees;
uint mtry;
uint min_node_size;
uint min_bucket;
std::vector<uint> min_node_size;
std::vector<uint> min_bucket;
size_t num_independent_variables;
uint seed;
size_t num_samples;
Expand Down
8 changes: 4 additions & 4 deletions src/ForestClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ void ForestClassification::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_CLASSIFICATION;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Create class_values and response_classIDs
Expand Down
8 changes: 4 additions & 4 deletions src/ForestProbability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ void ForestProbability::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_PROBABILITY;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_PROBABILITY;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Create class_values and response_classIDs
Expand Down
8 changes: 4 additions & 4 deletions src/ForestRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ void ForestRegression::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_REGRESSION;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_REGRESSION;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET;
}

// Error if beta splitrule used with data outside of [0,1]
Expand Down
8 changes: 4 additions & 4 deletions src/ForestSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ void ForestSurvival::initInternal() {
}

// Set minimal node size
if (min_node_size == 0) {
min_node_size = DEFAULT_MIN_NODE_SIZE_SURVIVAL;
if (min_node_size.size() == 1 && min_node_size[0] == 0) {
min_node_size[0] = DEFAULT_MIN_NODE_SIZE_SURVIVAL;
}

// Set minimal bucket size
if (min_bucket == 0) {
min_bucket = DEFAULT_MIN_BUCKET_SURVIVAL;
if (min_bucket.size() == 1 && min_bucket[0] == 0) {
min_bucket[0] = DEFAULT_MIN_BUCKET_SURVIVAL;
}

// Sort data if extratrees and not memory saving mode
Expand Down
Loading
Loading