From 75c24aeda708a8970daa3cca5c04d3122bb182f7 Mon Sep 17 00:00:00 2001
From: Valerio Besozzi
Date: Wed, 15 Nov 2023 15:10:17 +0100
Subject: [PATCH] Started refactoring on par_reduce

---
 src/templates/map.rs    |  4 ++--
 src/templates/misc.rs   | 16 ++++++++--------
 src/thread_pool/mod.rs  | 51 +++++++++++++++++++++++++++++++++++++++++++++++----
 src/thread_pool/test.rs | 19 +++++++++++++++++--
 4 files changed, 74 insertions(+), 16 deletions(-)

diff --git a/src/templates/map.rs b/src/templates/map.rs
index 92d09164..57af6b9e 100644
--- a/src/templates/map.rs
+++ b/src/templates/map.rs
@@ -357,7 +357,7 @@ where
     F: FnOnce(TKey, Vec<TReduce>) -> (TKey, TReduce) + Send + Copy,
 {
     fn run(&mut self, input: TInIter) -> Option<TOutIter> {
-        let res: TOutIter = self.threadpool.par_reduce(input, self.f).collect();
+        let res: TOutIter = self.threadpool.par_reduce_by_key(input, self.f).collect();
         Some(res)
     }
     fn number_of_replicas(&self) -> usize {
@@ -489,7 +489,7 @@ where
     F: FnOnce(TKey, Vec<TReduce>) -> (TKey, TReduce) + Send + Copy,
 {
     fn run(&mut self, input: TInIter) -> Option<TOutIter> {
-        let res: TOutIter = self.threadpool.par_reduce(input, self.f).collect();
+        let res: TOutIter = self.threadpool.par_reduce_by_key(input, self.f).collect();
         Some(res)
     }
     fn number_of_replicas(&self) -> usize {
diff --git a/src/templates/misc.rs b/src/templates/misc.rs
index 8d4baa0c..345d87fd 100644
--- a/src/templates/misc.rs
+++ b/src/templates/misc.rs
@@ -19,7 +19,7 @@ where
 {
     /// Creates a new source from any type that implements the `Iterator` trait.
     /// The source will terminate when the iterator is exhausted.
-    /// 
+    ///
     /// # Arguments
     /// * `iterator` - Type that implements the [`Iterator`] trait
     /// and represents the stream of data we want emit.
@@ -123,7 +123,7 @@ where
     ///
     /// # Arguments
     /// * `chunk_size` - Number of elements for each chunk.
-    /// 
+    ///
     /// # Examples
     /// Given a stream of numbers, we create a pipeline with a splitter that
     /// create vectors of two elements each.
@@ -204,7 +204,7 @@ where
     T: Send + 'static + Clone,
 {
     /// Creates a new aggregator node.
-    /// 
+    ///
     /// # Arguments
     /// * `chunk_size` - Number of elements for each chunk.
     ///
@@ -232,7 +232,7 @@ where
     }
 
     /// Creates a new aggregator node with 'n_replicas' replicas of the same node.
-    /// 
+    ///
     /// # Arguments
     /// * `n_replicas` - Number of replicas.
     /// * `chunk_size` - Number of elements for each chunk.
@@ -291,7 +291,7 @@ where
     F: FnMut(T) -> U + Send + 'static + Clone,
 {
     /// Creates a new sequential node.
-    /// 
+    ///
     /// # Arguments
     /// * `f` - Function name or lambda function that specify the logic
     /// of this node.
@@ -334,7 +334,7 @@ where
     F: FnMut(T) -> U + Send + 'static + Clone,
 {
     /// Creates a new parallel node.
-    /// 
+    ///
     /// # Arguments
     /// * `n_replicas` - Number of replicas.
     /// * `f` - Function name or lambda function that specify the logic
@@ -380,7 +380,7 @@ where
     F: FnMut(&T) -> bool + Send + 'static + Clone,
 {
     /// Creates a new filter node.
-    /// 
+    ///
     /// # Arguments
     /// * `f` - Function name or lambda function that represent the predicate
     /// function we want to apply.
@@ -561,7 +561,7 @@ where
     T: Send + 'static + Clone,
 {
     /// Creates a new ordered aggregator node
-    /// 
+    ///
     /// # Arguments
     /// * `chunk_size` - Number of elements for each chunk.
     pub fn build(chunk_size: usize) -> impl InOut<T, Vec<T>> {
diff --git a/src/thread_pool/mod.rs b/src/thread_pool/mod.rs
index ac7f893c..0ba4b22c 100644
--- a/src/thread_pool/mod.rs
+++ b/src/thread_pool/mod.rs
@@ -371,7 +371,7 @@ impl ThreadPool {
             });
         });
 
-        drop(arc_tx);
+        drop(arc_tx); // Refactoring?
 
         let mut disconnected = false;
 
@@ -442,7 +442,7 @@ impl ThreadPool {
         Iter: IntoIterator,
     {
         let map = self.par_map(iter, f);
-        self.par_reduce(map, reduce)
+        self.par_reduce_by_key(map, reduce)
     }
 
     /// Reduces in parallel the elements of an iterator `iter` by the function `f`.
@@ -468,12 +468,16 @@ impl ThreadPool {
     ///     vec.push((i % 10, i));
     /// }
     ///
-    /// let res: Vec<(i32, i32)> = pool.par_reduce(vec, |k, v| -> (i32, i32) {
+    /// let res: Vec<(i32, i32)> = pool.par_reduce_by_key(vec, |k, v| -> (i32, i32) {
     ///     (k, v.iter().sum())
     /// }).collect();
     /// assert_eq!(res.len(), 10);
     /// ```
-    pub fn par_reduce<Iter, K, V, F>(&mut self, iter: Iter, f: F) -> impl Iterator<Item = (K, V)>
+    pub fn par_reduce_by_key<Iter, K, V, F>(
+        &mut self,
+        iter: Iter,
+        f: F,
+    ) -> impl Iterator<Item = (K, V)>
     where
         <Iter as IntoIterator>::Item: Send,
         K: Send + Ord + 'static,
@@ -491,6 +495,45 @@ impl ThreadPool {
         self.par_map(ordered_map, move |(k, v)| f(k, v))
     }
 
+    /// Reduces in parallel the elements of an iterator `iter` into a single value by applying the binary function `f`.
+    ///
+    pub fn par_reduce<Iter, V, F>(&mut self, iter: Iter, f: F) -> V
+    where
+        <Iter as IntoIterator>::Item: Send,
+        V: Send + 'static,
+        F: FnOnce(V, V) -> V + Send + Copy + Sync,
+        Iter: IntoIterator<Item = V>,
+    {
+        let mut data: Vec<V> = iter.into_iter().collect();
+
+        while data.len() != 1 {
+            let mut tmp = Vec::new();
+            let mut num_proc = self.num_workers;
+
+            while data.len() < 2 * num_proc {
+                num_proc -= 1;
+            }
+            let mut counter = 0;
+
+            while !data.is_empty() {
+                counter %= num_proc;
+                tmp.push((counter, data.pop().unwrap()));
+                counter += 1;
+            }
+
+            data = self
+                .par_reduce_by_key(tmp, |k, v| {
+                    (k, v.into_iter().reduce(|a, b| f(a, b)).unwrap())
+                })
+                .collect::<Vec<(usize, V)>>()
+                .into_iter()
+                .map(|(_a, b)| b)
+                .collect();
+        }
+        data.pop().unwrap()
+
+    }
+
     /// Create a new scope to execute jobs on other threads.
     /// The function passed to this method will be provided with a [`Scope`] object,
     /// which can be used to spawn new jobs through the [`Scope::execute`] method.
diff --git a/src/thread_pool/test.rs b/src/thread_pool/test.rs
index 742f8de9..e2ea95f5 100644
--- a/src/thread_pool/test.rs
+++ b/src/thread_pool/test.rs
@@ -159,12 +159,27 @@ fn test_par_reduce() {
     }
 
     let res: Vec<(i32, i32)> = pool
-        .par_reduce(vec, |k, v| -> (i32, i32) { (k, v.iter().sum()) })
+        .par_reduce_by_key(vec, |k, v| -> (i32, i32) { (k, v.iter().sum()) })
         .collect();
 
     assert_eq!(res.len(), 10);
 }
 
+#[test]
+#[serial]
+fn test_new_reduce() {
+    let mut pool = ThreadPool::new();
+
+    let mut vec = Vec::new();
+    for _i in 0..130 {
+        vec.push(1);
+    }
+
+    let res = pool.par_reduce(vec, |a, b| a + b);
+
+    assert_eq!(res, 130);
+}
+
 #[test]
 #[serial]
 fn test_par_map_reduce_seq() {
@@ -178,7 +193,7 @@ fn test_par_map_reduce_seq() {
     }
 
     let res = tp.par_map(vec, |el| -> (i32, i32) { (el, 1) });
-    let res = tp.par_reduce(res, |k, v| (k, v.iter().sum::<i32>()));
+    let res = tp.par_reduce_by_key(res, |k, v| (k, v.iter().sum::<i32>()));
 
     let mut check = true;
     for (k, v) in res {
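For reference, a minimal usage sketch of the reworked API, mirroring the doc example and the new `test_new_reduce` test in the patch. The `use ppl::thread_pool::ThreadPool;` import path is an assumption about how the crate exports `ThreadPool`; adjust it to the actual re-export.

```rust
// Sketch only: mirrors the calls shown in the patch's doc example and tests.
// Assumption: `ThreadPool` is reachable at this path in the `ppl` crate.
use ppl::thread_pool::ThreadPool;

fn main() {
    let mut pool = ThreadPool::new();

    // Key-based reduction (formerly `par_reduce`, now `par_reduce_by_key`):
    // (key, value) pairs are grouped by key and each group is reduced independently.
    let mut pairs = Vec::new();
    for i in 0..100 {
        pairs.push((i % 10, i));
    }
    let by_key: Vec<(i32, i32)> = pool
        .par_reduce_by_key(pairs, |k, v| -> (i32, i32) { (k, v.iter().sum()) })
        .collect();
    assert_eq!(by_key.len(), 10);

    // Whole-iterator reduction (the new `par_reduce`): every element is folded
    // into a single value with a binary operator, no keys involved.
    let ones = vec![1; 130];
    let total = pool.par_reduce(ones, |a, b| a + b);
    assert_eq!(total, 130);
}
```

Splitting the old entry point this way keeps the grouped, key-ordered reduction path unchanged, while `par_reduce` gives callers a plain parallel fold, implemented on top of `par_reduce_by_key` by bucketing elements round-robin across the workers until a single value remains.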