diff --git a/rten-simd/src/isa_detection.rs b/rten-simd/src/isa_detection.rs
index 38ccbb15..65d398bb 100644
--- a/rten-simd/src/isa_detection.rs
+++ b/rten-simd/src/isa_detection.rs
@@ -3,7 +3,7 @@
 /// Functions for reading system info on macOS.
 #[cfg(target_os = "macos")]
 #[allow(unused)]
-mod macos {
+pub mod macos {
     /// Detect availability of AVX-512 on macOS, where `is_x86_feature_detected`
     /// can return false even if AVX-512 is available.
     ///
diff --git a/src/threading.rs b/src/threading.rs
index bf92c96f..66d47dee 100644
--- a/src/threading.rs
+++ b/src/threading.rs
@@ -24,6 +24,36 @@ impl ThreadPool {
     }
 }
 
+/// Return the optimal number of cores to use for maximum performance.
+///
+/// This may be less than the total number of cores on systems with
+/// heterogeneous cores (eg. a mix of performance and efficiency). The result
+/// is always at least 1.
+fn optimal_core_count() -> u32 {
+    // The macOS-only block below is the sole writer, so other targets would
+    // otherwise warn about an unused `mut`.
+    #[allow(unused_mut)]
+    let mut core_count = num_cpus::get_physical().max(1) as u32;
+
+    #[cfg(target_os = "macos")]
+    {
+        use rten_simd::isa_detection::macos::sysctl_int;
+
+        // Ignore counts that are not positive or do not fit in `u32`: a zero
+        // upper bound makes `clamp(1, max)` panic and, if the value is signed,
+        // a negative count would wrap to a huge limit under `as u32`.
+        let perf_core_count = sysctl_int(c"hw.perflevel0.physicalcpu")
+            .ok()
+            .and_then(|count| u32::try_from(count).ok())
+            .filter(|&count| count >= 1);
+        if let Some(perf_core_count) = perf_core_count {
+            core_count = core_count.min(perf_core_count);
+        }
+    }
+
+    core_count
+}
+
 /// Return the [Rayon][rayon] thread pool which is used to execute RTen models.
 ///
 /// This differs from Rayon's default global thread pool in that it is tuned for
@@ -41,12 +71,12 @@ impl ThreadPool {
 pub fn thread_pool() -> &'static ThreadPool {
     static THREAD_POOL: OnceLock<ThreadPool> = OnceLock::new();
     THREAD_POOL.get_or_init(|| {
-        let physical_cpus = num_cpus::get_physical();
+        let physical_cpus = optimal_core_count();
 
         let num_threads = if let Some(threads_var) = env::var_os("RTEN_NUM_THREADS") {
-            let requested_threads: Result<usize, _> = threads_var.to_string_lossy().parse();
+            let requested_threads: Result<u32, _> = threads_var.to_string_lossy().parse();
             match requested_threads {
-                Ok(n_threads) => n_threads.clamp(1, num_cpus::get()),
+                Ok(n_threads) => n_threads.clamp(1, num_cpus::get() as u32),
                 Err(_) => physical_cpus,
             }
         } else {
@@ -54,10 +84,22 @@ pub fn thread_pool() -> &'static ThreadPool {
         };
 
         let pool = rayon::ThreadPoolBuilder::new()
-            .num_threads(num_threads)
+            .num_threads(num_threads as usize)
             .thread_name(|index| format!("rten-{}", index))
             .build();
 
         ThreadPool { pool: pool.ok() }
     })
 }
+
+#[cfg(test)]
+mod tests {
+    use super::optimal_core_count;
+
+    #[test]
+    fn test_optimal_core_count() {
+        let max_cores = num_cpus::get_physical() as u32;
+        let opt_cores = optimal_core_count();
+        assert!(opt_cores >= 1 && opt_cores <= max_cores);
+    }
+}