From 52b797d520ce592eb11b63711622a481d29408bd Mon Sep 17 00:00:00 2001 From: Lorenzo Mec-iS Date: Mon, 20 Jan 2025 17:18:09 +0000 Subject: [PATCH] format --- src/ensemble/random_forest_classifier.rs | 62 +++++++++++++++++------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index f03c9cc7..19d75f38 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -55,11 +55,11 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; +use crate::linalg::basic::arrays::MutArray; use crate::linalg::basic::arrays::{Array1, Array2}; +use crate::linalg::basic::matrix::DenseMatrix; use crate::numbers::basenum::Number; use crate::numbers::floatnum::FloatNumber; -use crate::linalg::basic::matrix::DenseMatrix; -use crate::linalg::basic::arrays::MutArray; use crate::rand_custom::get_rng_impl; use crate::tree::decision_tree_classifier::{ @@ -667,16 +667,15 @@ impl, Y: Array1 = vec![0, 0, 0, 0, 0, 1, 1, 1, 1, 1]; let forest = RandomForestClassifier::fit(&x, &y, Default::default()).unwrap(); @@ -858,12 +858,21 @@ mod tests { // Test probability sum for i in 0..10 { let row_sum: f64 = probas.get_row(i).sum(); - assert!((row_sum - 1.0).abs() < 1e-6, "Row probabilities should sum to 1"); + assert!( + (row_sum - 1.0).abs() < 1e-6, + "Row probabilities should sum to 1" + ); } // Test class prediction let predictions: Vec = (0..10) - .map(|i| if probas.get((i, 0)) > probas.get((i, 1)) { 0 } else { 1 }) + .map(|i| { + if probas.get((i, 0)) > probas.get((i, 1)) { + 0 + } else { + 1 + } + }) .collect(); let acc = accuracy(&y, &predictions); assert!(acc > 0.8, "Accuracy should be high for the training set"); @@ -871,23 +880,42 @@ mod tests { // Test probability values // These values are approximate and based on typical random forest behavior for i in 0..5 { - assert!(*probas.get((i, 0)) > 0.6, "Class 0 samples should have high probability for class 0"); - assert!(*probas.get((i, 1)) < 0.4, "Class 0 samples should have low probability for class 1"); + assert!( + *probas.get((i, 0)) > 0.6, + "Class 0 samples should have high probability for class 0" + ); + assert!( + *probas.get((i, 1)) < 0.4, + "Class 0 samples should have low probability for class 1" + ); } for i in 5..10 { - assert!(*probas.get((i, 1)) > 0.6, "Class 1 samples should have high probability for class 1"); - assert!(*probas.get((i, 0)) < 0.4, "Class 1 samples should have low probability for class 0"); + assert!( + *probas.get((i, 1)) > 0.6, + "Class 1 samples should have high probability for class 1" + ); + assert!( + *probas.get((i, 0)) < 0.4, + "Class 1 samples should have low probability for class 0" + ); } // Test with new data let x_new = DenseMatrix::from_2d_array(&[ - &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 - &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 - ]).unwrap(); + &[5.0, 3.4, 1.5, 0.2], // Should be close to class 0 + &[6.3, 3.3, 4.7, 1.6], // Should be close to class 1 + ]) + .unwrap(); let probas_new = forest.predict_proba(&x_new).unwrap(); assert_eq!(probas_new.shape(), (2, 2)); - assert!(probas_new.get((0, 0)) > probas_new.get((0, 1)), "First sample should be predicted as class 0"); - assert!(probas_new.get((1, 1)) > probas_new.get((1, 0)), "Second sample should be predicted as class 1"); + assert!( + probas_new.get((0, 0)) > probas_new.get((0, 1)), + "First sample should be predicted as class 0" + ); + assert!( + probas_new.get((1, 1)) > probas_new.get((1, 0)), + "Second sample should be predicted as class 1" + ); } #[cfg_attr(