From 00ca59aaa78ba815fe15937fb09f508e4934270c Mon Sep 17 00:00:00 2001
From: zhangli20
Date: Thu, 26 Dec 2024 19:19:31 +0800
Subject: [PATCH] Support falling back from hash join to sort-merge join when
 the hash table is too big.

---
 .github/workflows/tpcds.yml                   |   1 -
 native-engine/blaze-jni-bridge/src/conf.rs    |  32 +-
 .../src/algorithm/mod.rs                      |   4 +-
 .../{rdx_tournament_tree.rs => rdx_queue.rs}  |  22 +-
 .../src/algorithm/{rdxsort.rs => rdx_sort.rs} |   0
 .../datafusion-ext-commons/src/lib.rs         |  18 +-
 .../datafusion-ext-plans/src/agg/agg_ctx.rs   |  13 +-
 .../datafusion-ext-plans/src/agg/agg_table.rs |  10 +-
 .../src/broadcast_join_build_hash_map_exec.rs | 119 ++++++--
 .../src/broadcast_join_exec.rs                | 278 ++++++++++++------
 .../src/common/execution_context.rs           |  24 +-
 .../datafusion-ext-plans/src/common/mod.rs    |   2 +
 .../src/common/offsetted.rs                   | 192 ++++++++++++
 .../src/common/stream_exec.rs                 |  72 +++++
 .../src/joins/bhj/full_join.rs                |  22 +-
 .../src/joins/join_hash_map.rs                | 107 ++++---
 .../datafusion-ext-plans/src/joins/test.rs    |   2 +
 .../datafusion-ext-plans/src/memmgr/spill.rs  |  27 ++
 .../src/shuffle/buffered_data.rs              | 203 ++++++-------
 .../datafusion-ext-plans/src/shuffle/mod.rs   |   9 +-
 .../src/shuffle/sort_repartitioner.rs         | 107 ++-----
 .../datafusion-ext-plans/src/sort_exec.rs     |  21 +-
 .../src/sort_merge_join_exec.rs               |  32 +-
 .../org/apache/spark/sql/blaze/BlazeConf.java |  11 +-
 .../apache/spark/sql/blaze/NativeHelper.scala |   1 +
 .../blaze/plan/NativeBroadcastJoinBase.scala  |   1 +
 .../plan/NativeShuffledHashJoinBase.scala     |  11 +-
 27 files changed, 894 insertions(+), 447 deletions(-)
 rename native-engine/datafusion-ext-commons/src/algorithm/{rdx_tournament_tree.rs => rdx_queue.rs} (89%)
 rename native-engine/datafusion-ext-commons/src/algorithm/{rdxsort.rs => rdx_sort.rs} (100%)
 create mode 100644 native-engine/datafusion-ext-plans/src/common/offsetted.rs
 create mode 100644 native-engine/datafusion-ext-plans/src/common/stream_exec.rs

diff --git a/.github/workflows/tpcds.yml b/.github/workflows/tpcds.yml
index d5aff038..fde34011 100644
--- a/.github/workflows/tpcds.yml
+++ b/.github/workflows/tpcds.yml
@@ -2,7 +2,6 @@ name: TPC-DS

 on:
   workflow_dispatch:
-  push:
   pull_request_target:
     types:
       - opened
diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs
index 7e65bfa0..cd2fea0d 100644
--- a/native-engine/blaze-jni-bridge/src/conf.rs
+++ b/native-engine/blaze-jni-bridge/src/conf.rs
@@ -12,9 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use datafusion::common::Result;
+use datafusion::common::{DataFusionError, Result};

-use crate::{jni_call_static, jni_get_string, jni_new_string};
+use crate::{is_jni_bridge_inited, jni_call_static, jni_get_string, jni_new_string};

 macro_rules! define_conf {
     ($conftype:ty, $name:ident) => {
@@ -42,10 +42,18 @@ define_conf!(BooleanConf, PARQUET_ENABLE_BLOOM_FILTER);
 define_conf!(StringConf, SPARK_IO_COMPRESSION_CODEC);
 define_conf!(IntConf, SPARK_TASK_CPUS);
 define_conf!(StringConf, SPILL_COMPRESSION_CODEC);
+define_conf!(BooleanConf, SMJ_FALLBACK_ENABLE);
+define_conf!(IntConf, SMJ_FALLBACK_ROWS_THRESHOLD);
+define_conf!(IntConf, SMJ_FALLBACK_MEM_SIZE_THRESHOLD);

 pub trait BooleanConf {
     fn key(&self) -> &'static str;
     fn value(&self) -> Result<bool> {
+        if !is_jni_bridge_inited() {
+            return Err(DataFusionError::Execution(
+                "JNIEnv not initialized".to_string(),
+            ));
+        }
         let key = jni_new_string!(self.key())?;
         jni_call_static!(BlazeConf.booleanConf(key.as_obj()) -> bool)
     }
@@ -54,6 +62,11 @@ pub trait IntConf {
     fn key(&self) -> &'static str;
     fn value(&self) -> Result<i32> {
+        if !is_jni_bridge_inited() {
+            return Err(DataFusionError::Execution(
+                "JNIEnv not initialized".to_string(),
+            ));
+        }
         let key = jni_new_string!(self.key())?;
         jni_call_static!(BlazeConf.intConf(key.as_obj()) -> i32)
     }
@@ -62,6 +75,11 @@ pub trait LongConf {
     fn key(&self) -> &'static str;
     fn value(&self) -> Result<i64> {
+        if !is_jni_bridge_inited() {
+            return Err(DataFusionError::Execution(
+                "JNIEnv not initialized".to_string(),
+            ));
+        }
         let key = jni_new_string!(self.key())?;
         jni_call_static!(BlazeConf.longConf(key.as_obj()) -> i64)
     }
@@ -70,6 +88,11 @@ pub trait DoubleConf {
     fn key(&self) -> &'static str;
     fn value(&self) -> Result<f64> {
+        if !is_jni_bridge_inited() {
+            return Err(DataFusionError::Execution(
+                "JNIEnv not initialized".to_string(),
+            ));
+        }
         let key = jni_new_string!(self.key())?;
         jni_call_static!(BlazeConf.doubleConf(key.as_obj()) -> f64)
     }
@@ -78,6 +101,11 @@ pub trait StringConf {
     fn key(&self) -> &'static str;
     fn value(&self) -> Result<String> {
+        if !is_jni_bridge_inited() {
+            return Err(DataFusionError::Execution(
+                "JNIEnv not initialized".to_string(),
+            ));
+        }
         let key = jni_new_string!(self.key())?;
         let value = jni_get_string!(
             jni_call_static!(BlazeConf.stringConf(key.as_obj()) -> JObject)?
         )?;
         Ok(value)
     }
 }
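The three SMJ_FALLBACK_* entries go through the same define_conf!/trait machinery as the
existing confs, and value() now fails fast instead of panicking when the JNI bridge is not
up. A minimal sketch of the intended call pattern on the native side (the conf names are
from this patch; the helper function itself is illustrative only):

    use blaze_jni_bridge::conf::{self, BooleanConf, IntConf};

    // Illustrative helper, not part of the patch: read the fallback limits,
    // defaulting to "disabled" when running without an initialized JNIEnv
    // (e.g. in native unit tests), exactly as the executor code below does.
    fn smj_fallback_limits() -> (bool, usize, usize) {
        (
            conf::SMJ_FALLBACK_ENABLE.value().unwrap_or(false),
            conf::SMJ_FALLBACK_ROWS_THRESHOLD.value().unwrap_or(i32::MAX) as usize,
            conf::SMJ_FALLBACK_MEM_SIZE_THRESHOLD.value().unwrap_or(i32::MAX) as usize,
        )
    }
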
diff --git a/native-engine/datafusion-ext-commons/src/algorithm/mod.rs b/native-engine/datafusion-ext-commons/src/algorithm/mod.rs
index 2a562a24..2cf52d6c 100644
--- a/native-engine/datafusion-ext-commons/src/algorithm/mod.rs
+++ b/native-engine/datafusion-ext-commons/src/algorithm/mod.rs
@@ -13,5 +13,5 @@
 // limitations under the License.

 pub mod loser_tree;
-pub mod rdx_tournament_tree;
-pub mod rdxsort;
+pub mod rdx_queue;
+pub mod rdx_sort;
diff --git a/native-engine/datafusion-ext-commons/src/algorithm/rdx_tournament_tree.rs b/native-engine/datafusion-ext-commons/src/algorithm/rdx_queue.rs
similarity index 89%
rename from native-engine/datafusion-ext-commons/src/algorithm/rdx_tournament_tree.rs
rename to native-engine/datafusion-ext-commons/src/algorithm/rdx_queue.rs
index 663f8920..419cbcea 100644
--- a/native-engine/datafusion-ext-commons/src/algorithm/rdx_tournament_tree.rs
+++ b/native-engine/datafusion-ext-commons/src/algorithm/rdx_queue.rs
@@ -16,13 +16,13 @@
 use std::ops::{Deref, DerefMut};

 use unchecked_index::UncheckedIndex;

-pub trait KeyForRadixTournamentTree {
+pub trait KeyForRadixQueue {
     fn rdx(&self) -> usize;
 }

 /// An implementation of the radix tournament tree
 /// with time complexity of sorting all values: O(n + K)
-pub struct RadixTournamentTree<T: KeyForRadixTournamentTree> {
+pub struct RadixQueue<T: KeyForRadixQueue> {
     num_keys: usize,
     cur_rdx: usize,
     values: UncheckedIndex<Vec<T>>,
@@ -31,7 +31,7 @@ pub struct RadixQueue<T: KeyForRadixQueue> {
 }

 #[allow(clippy::len_without_is_empty)]
-impl<T: KeyForRadixTournamentTree> RadixTournamentTree<T> {
+impl<T: KeyForRadixQueue> RadixQueue<T> {
     pub fn new(values: Vec<T>, num_keys: usize) -> Self {
         let num_keys = num_keys + 1; // avoid overflow
         let num_values = values.len();
@@ -117,12 +117,12 @@ impl<T: KeyForRadixQueue> RadixQueue<T> {

 /// A PeekMut structure to the loser tree, used to get smallest value and auto
 /// adjusting after dropped.
-pub struct RadixTournamentTreePeekMut<'a, T: KeyForRadixTournamentTree> {
-    tree: &'a mut RadixTournamentTree<T>,
+pub struct RadixTournamentTreePeekMut<'a, T: KeyForRadixQueue> {
+    tree: &'a mut RadixQueue<T>,
     dirty: bool,
 }

-impl<T: KeyForRadixTournamentTree> Deref for RadixTournamentTreePeekMut<'_, T> {
+impl<T: KeyForRadixQueue> Deref for RadixTournamentTreePeekMut<'_, T> {
     type Target = T;

     fn deref(&self) -> &Self::Target {
@@ -130,7 +130,7 @@ impl<T: KeyForRadixQueue> Deref for RadixTournamentTreePeekMut<'_, T> {
     }
 }

-impl<T: KeyForRadixTournamentTree> DerefMut for RadixTournamentTreePeekMut<'_, T> {
+impl<T: KeyForRadixQueue> DerefMut for RadixTournamentTreePeekMut<'_, T> {
     fn deref_mut(&mut self) -> &mut Self::Target {
         self.dirty = true;
         &mut self.tree.values[self
@@ -142,7 +142,7 @@ impl<T: KeyForRadixQueue> DerefMut for RadixTournamentTreePeekMut<'_, T> {
     }
 }

-impl<T: KeyForRadixTournamentTree> Drop for RadixTournamentTreePeekMut<'_, T> {
+impl<T: KeyForRadixQueue> Drop for RadixTournamentTreePeekMut<'_, T> {
     fn drop(&mut self) {
         if self.dirty {
             self.tree.adjust_tree();
@@ -155,7 +155,7 @@ mod test {
     use itertools::Itertools;
     use rand::Rng;

-    use crate::algorithm::rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree};
+    use crate::algorithm::rdx_queue::{KeyForRadixQueue, RadixQueue};

     #[test]
     fn fuzztest() {
@@ -184,12 +184,12 @@ mod test {
             row_idx: usize,
             values: Vec<u64>,
         }
-        impl KeyForRadixTournamentTree for Cursor {
+        impl KeyForRadixQueue for Cursor {
             fn rdx(&self) -> usize {
                 self.values.get(self.row_idx).cloned().unwrap_or(u64::MAX) as usize
             }
         }
-        let mut loser_tree = RadixTournamentTree::new(
+        let mut loser_tree = RadixQueue::new(
             nodes
                 .into_iter()
                 .map(|node| Cursor {
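Renaming RadixTournamentTree to RadixQueue only changes names; the O(n + K) behavior is
unchanged. A small usage sketch of the renamed API (illustrative only, based on the trait
and constructor shown above):

    use datafusion_ext_commons::algorithm::rdx_queue::{KeyForRadixQueue, RadixQueue};

    struct Item(usize);
    impl KeyForRadixQueue for Item {
        fn rdx(&self) -> usize {
            self.0 // the radix key; the queue yields items in ascending rdx() order
        }
    }

    // num_keys = 4 means rdx() values are expected to fall in 0..4
    let queue = RadixQueue::new(vec![Item(3), Item(0), Item(2)], 4);
    assert_eq!(queue.peek().0, 0);
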
diff --git a/native-engine/datafusion-ext-commons/src/algorithm/rdxsort.rs b/native-engine/datafusion-ext-commons/src/algorithm/rdx_sort.rs
similarity index 100%
rename from native-engine/datafusion-ext-commons/src/algorithm/rdxsort.rs
rename to native-engine/datafusion-ext-commons/src/algorithm/rdx_sort.rs
diff --git a/native-engine/datafusion-ext-commons/src/lib.rs b/native-engine/datafusion-ext-commons/src/lib.rs
index 5970d3d2..44ae2fa4 100644
--- a/native-engine/datafusion-ext-commons/src/lib.rs
+++ b/native-engine/datafusion-ext-commons/src/lib.rs
@@ -18,10 +18,7 @@
 #![feature(slice_swap_unchecked)]
 #![feature(vec_into_raw_parts)]

-use blaze_jni_bridge::{
-    conf::{IntConf, BATCH_SIZE},
-    is_jni_bridge_inited,
-};
+use blaze_jni_bridge::conf::{IntConf, BATCH_SIZE};
 use once_cell::sync::OnceCell;
 use unchecked_index::UncheckedIndex;

@@ -71,17 +68,8 @@ macro_rules! downcast_any {
 }

 pub fn batch_size() -> usize {
-    const CACHED_BATCH_SIZE: OnceCell<i32> = OnceCell::new();
-    let batch_size = *CACHED_BATCH_SIZE
-        .get_or_try_init(|| {
-            if is_jni_bridge_inited() {
-                BATCH_SIZE.value()
-            } else {
-                Ok(10000) // for testing
-            }
-        })
-        .expect("error getting configured batch size") as usize;
-    batch_size
+    static CACHED_BATCH_SIZE: OnceCell<usize> = OnceCell::new();
+    *CACHED_BATCH_SIZE.get_or_init(|| BATCH_SIZE.value().unwrap_or(10000) as usize)
 }

 // bigger for better radix sort performance
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
index 16ae3f39..b54616f1 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg_ctx.rs
@@ -26,7 +26,6 @@ use arrow::{
 use blaze_jni_bridge::{
     conf,
     conf::{DoubleConf, IntConf},
-    is_jni_bridge_inited,
 };
 use datafusion::{
     common::{cast::as_binary_array, Result},
@@ -164,14 +163,10 @@ impl AggContext {
         )?;

         let (partial_skipping_ratio, partial_skipping_min_rows) = if supports_partial_skipping {
-            if is_jni_bridge_inited() {
-                (
-                    conf::PARTIAL_AGG_SKIPPING_RATIO.value()?,
-                    conf::PARTIAL_AGG_SKIPPING_MIN_ROWS.value()? as usize,
-                )
-            } else {
-                (0.999, 20000) // only for testing
-            }
+            (
+                conf::PARTIAL_AGG_SKIPPING_RATIO.value().unwrap_or(0.999),
+                conf::PARTIAL_AGG_SKIPPING_MIN_ROWS.value().unwrap_or(20000) as usize,
+            )
         } else {
             Default::default()
         };
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs
index 2c672ad5..1a8354c0 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg_table.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg_table.rs
@@ -27,8 +27,8 @@ use datafusion::{
 };
 use datafusion_ext_commons::{
     algorithm::{
-        rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree},
-        rdxsort::radix_sort_by_key,
+        rdx_queue::{KeyForRadixQueue, RadixQueue},
+        rdx_sort::radix_sort_by_key,
     },
     batch_size, df_execution_err, downcast_any,
     io::{read_bytes_slice, read_len, write_len},
@@ -215,8 +215,8 @@ impl AggTable {

         // create a radix tournament tree to do the merging
         // the mem-table and at least one spill should be in the tree
-        let mut cursors: RadixTournamentTree<RecordsSpillCursor> =
-            RadixTournamentTree::new(cursors, NUM_SPILL_BUCKETS);
+        let mut cursors: RadixQueue<RecordsSpillCursor> =
+            RadixQueue::new(cursors, NUM_SPILL_BUCKETS);
         assert!(cursors.len() > 0);

         let mut map = AggHashMap::default();
@@ -698,7 +698,7 @@ impl<'a> RecordsSpillCursor<'a> {
     }
 }

-impl<'a> KeyForRadixTournamentTree for RecordsSpillCursor<'a> {
+impl<'a> KeyForRadixQueue for RecordsSpillCursor<'a> {
     fn rdx(&self) -> usize {
         self.cur_bucket_idx
     }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
index 047a682c..63d71e1e 100644
--- a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
@@ -18,13 +18,22 @@ use std::{
     sync::Arc,
 };

-use arrow::{array::RecordBatch, datatypes::SchemaRef};
+use arrow::{
+    array::{new_null_array, RecordBatch},
+    datatypes::SchemaRef,
+};
+use arrow_schema::DataType;
+use blaze_jni_bridge::{
+    conf,
+    conf::{BooleanConf, IntConf},
+};
 use datafusion::{
     common::Result,
     execution::{SendableRecordBatchStream, TaskContext},
     physical_expr::{EquivalenceProperties, Partitioning, PhysicalExpr},
     physical_plan::{
-        metrics::{ExecutionPlanMetricsSet, MetricsSet},
+        metrics::{ExecutionPlanMetricsSet, MetricsSet, Time},
+        stream::RecordBatchStreamAdapter,
         DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties,
         PlanProperties,
     },
@@ -34,8 +43,12 @@ use futures::StreamExt;
 use once_cell::sync::OnceCell;

 use crate::{
-    common::{execution_context::ExecutionContext, timer_helper::TimerHelper},
+    common::{
+        execution_context::ExecutionContext, stream_exec::create_record_batch_stream_exec,
+        timer_helper::TimerHelper,
+    },
     joins::join_hash_map::{join_hash_map_schema, JoinHashMap},
+    sort_exec::create_default_ascending_sort_exec,
 };

 pub struct BroadcastJoinBuildHashMapExec {
@@ -111,7 +124,8 @@ impl ExecutionPlan for BroadcastJoinBuildHashMapExec {
     ) -> Result<SendableRecordBatchStream> {
         let exec_ctx = ExecutionContext::new(context, partition, self.schema(), &self.metrics);
         let input = exec_ctx.execute(&self.input)?;
-        execute_build_hash_map(input, self.keys.clone(), exec_ctx)
+        let build_time = exec_ctx.register_timer_metric("build_time");
+        execute_build_hash_map(input, self.keys.clone(), exec_ctx, build_time)
     }

     fn metrics(&self) -> Option<MetricsSet> {
@@ -119,42 +133,97 @@ impl ExecutionPlan for BroadcastJoinBuildHashMapExec {
     }
 }

-pub fn collect_hash_map(
-    data_schema: SchemaRef,
-    data_batches: Vec<RecordBatch>,
-    keys: Vec<Arc<dyn PhysicalExpr>>,
-) -> Result<JoinHashMap> {
-    let data_batch = coalesce_batches_unchecked(data_schema, &data_batches);
-    let hash_map = JoinHashMap::create_from_data_batch(data_batch, &keys)?;
-    Ok(hash_map)
-}
-
-fn execute_build_hash_map(
+pub fn execute_build_hash_map(
     mut input: SendableRecordBatchStream,
     keys: Vec<Arc<dyn PhysicalExpr>>,
     exec_ctx: Arc<ExecutionContext>,
+    build_time: Time,
 ) -> Result<SendableRecordBatchStream> {
     // output hash map batches as stream
     Ok(exec_ctx
         .clone()
         .output_with_sender("BuildHashMap", move |sender| async move {
-            let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone();
-            let _timer = elapsed_compute.timer();
+            sender.exclude_time(&build_time);
+            let _timer = build_time.timer();
+
+            let smj_fallback_enabled = conf::SMJ_FALLBACK_ENABLE.value().unwrap_or(false);
+            let smj_fallback_rows_threshold = conf::SMJ_FALLBACK_ROWS_THRESHOLD
+                .value()
+                .unwrap_or(i32::MAX) as usize;
+            let smj_fallback_mem_threshold = conf::SMJ_FALLBACK_MEM_SIZE_THRESHOLD
+                .value()
+                .unwrap_or(i32::MAX) as usize;

-            // collect all input batches
-            let mut data_batches = vec![];
-            while let Some(batch) = elapsed_compute
+            let data_schema = input.schema();
+            let mut staging_batches: Vec<RecordBatch> = vec![];
+            let mut staging_num_rows = 0;
+            let mut staging_mem_size = 0;
+            let mut fallback_to_sorted = false;
+
+            while let Some(batch) = build_time
                 .exclude_timer_async(input.next())
                 .await
                 .transpose()?
             {
-                data_batches.push(batch);
+                staging_batches.push(batch.clone());
+                if smj_fallback_enabled {
+                    staging_num_rows += batch.num_rows();
+                    staging_mem_size += batch.get_array_memory_size();
+
+                    // fallback if staging data is too large
+                    if staging_num_rows > smj_fallback_rows_threshold
+                        || staging_mem_size > smj_fallback_mem_threshold
+                    {
+                        fallback_to_sorted = true;
+                        break;
+                    }
+                }
             }

-            // build hash map
-            let data_schema = input.schema();
-            let hash_map = collect_hash_map(data_schema, data_batches, keys)?;
-            sender.send(hash_map.into_hash_map_batch()?).await;
+            // no fallbacks - generate one hashmap batch
+            if !fallback_to_sorted {
+                let data_batch =
+                    coalesce_batches_unchecked(data_schema, &std::mem::take(&mut staging_batches));
+                let hash_map = JoinHashMap::create_from_data_batch(data_batch, &keys)?;
+                sender.send(hash_map.into_hash_map_batch()?).await;
+                exec_ctx
+                    .baseline_metrics()
+                    .elapsed_compute()
+                    .add_duration(build_time.duration());
+                return Ok(());
+            }
+
+            // fallback to sort-merge join
+            // sort all input data
+            let input: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new(
+                data_schema,
+                futures::stream::iter(staging_batches.into_iter().map(|batch| Ok(batch)))
+                    .chain(input),
+            ));
+            let input_exec = create_record_batch_stream_exec(input, exec_ctx.partition_id())?;
+            let sort_exec = create_default_ascending_sort_exec(input_exec, &keys);
+            let mut sorted_stream =
+                sort_exec.execute(exec_ctx.partition_id(), exec_ctx.task_ctx())?;
+
+            // append a null table data column
+            let hash_map_batch_schema = join_hash_map_schema(&sorted_stream.schema());
+            while let Some(batch) = sorted_stream.next().await.transpose()? {
+                let null_table_data_column = new_null_array(&DataType::Binary, batch.num_rows());
+                let sorted_hash_map_batch = RecordBatch::try_new(
+                    hash_map_batch_schema.clone(),
+                    batch
+                        .columns()
+                        .iter()
+                        .cloned()
+                        .chain(Some(null_table_data_column))
+                        .collect(),
+                )?;
+                sender.send(sorted_hash_map_batch).await;
+            }
+            exec_ctx
+                .baseline_metrics()
+                .elapsed_compute()
+                .add_duration(build_time.duration());
             Ok(())
         }))
 }
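Worth spelling out the wire format of the fallback: a regular build emits a single batch
whose trailing "~TABLE" column carries the serialized hash table in row 0, while the
fallback path emits sorted data batches whose "~TABLE" column is all nulls. The probe side
distinguishes the two cases with exactly this check (mirrors
JoinHashMap::record_batch_contains_hash_map added later in this patch; assumes a non-empty
batch):

    use arrow::array::RecordBatch;

    fn contains_hash_map(batch: &RecordBatch) -> bool {
        // the serialized table lives in row 0 of the last column;
        // SMJ-fallback batches carry only nulls there
        let table_col = batch.column(batch.num_columns() - 1);
        table_col.is_valid(0)
    }
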
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
index 4d828934..a05d33e9 100644
--- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -18,16 +18,18 @@ use std::{
     future::Future,
     pin::Pin,
     sync::{Arc, Weak},
+    time::Duration,
 };

 use arrow::{
     array::RecordBatch,
-    compute::{concat_batches, SortOptions},
+    compute::SortOptions,
     datatypes::{DataType, SchemaRef},
 };
+use arrow_schema::Schema;
 use async_trait::async_trait;
 use datafusion::{
-    common::{DataFusionError, JoinSide, Result, Statistics},
+    common::{JoinSide, Result, Statistics},
     execution::context::TaskContext,
     physical_expr::{EquivalenceProperties, PhysicalExprRef},
     physical_plan::{
@@ -40,14 +42,17 @@ use datafusion::{
 };
 use datafusion_ext_commons::{batch_size, df_execution_err};
 use futures::{StreamExt, TryStreamExt};
+use futures_util::stream::Peekable;
 use hashbrown::HashMap;
 use once_cell::sync::OnceCell;
 use parking_lot::Mutex;

 use crate::{
+    broadcast_join_build_hash_map_exec::execute_build_hash_map,
     common::{
         column_pruning::ExecuteWithColumnPruning,
         execution_context::{ExecutionContext, WrappedRecordBatchSender},
+        stream_exec::create_record_batch_stream_exec,
         timer_helper::TimerHelper,
     },
     joins::{
@@ -67,6 +72,8 @@ use crate::{
         join_utils::{JoinType, JoinType::*},
         JoinParams, JoinProjection,
     },
+    sort_exec::create_default_ascending_sort_exec,
+    sort_merge_join_exec::SortMergeJoinExec,
 };

 #[derive(Debug)]
@@ -315,7 +322,9 @@ async fn execute_join_with_map(
     build_output_time: Time,
     sender: Arc<WrappedRecordBatchSender>,
 ) -> Result<()> {
-    let _timer = exec_ctx.baseline_metrics().elapsed_compute().timer();
+    let elapsed_compute = exec_ctx.baseline_metrics().elapsed_compute().clone();
+    let _timer = elapsed_compute.timer();
+
     let mut joiner: Pin<Box<dyn Joiner + Send>> = match broadcast_side {
         JoinSide::Left => match join_params.join_type {
             Inner => Box::pin(RProbedInnerJoiner::new(join_params, map, sender)),
@@ -367,6 +376,96 @@ async fn execute_join_with_map(
     Ok(())
 }

+async fn execute_join_with_smj_fallback(
+    probed: SendableRecordBatchStream,
+    built: SendableRecordBatchStream,
+    join_params: JoinParams,
+    broadcast_side: JoinSide,
+    exec_ctx: Arc<ExecutionContext>,
+    sender: Arc<WrappedRecordBatchSender>,
+) -> Result<()> {
+    // remove the table data column from the built stream
+    let built_schema = built.schema();
+    let built_sorted: Arc<dyn ExecutionPlan> = {
+        let removed_schema = {
+            let mut fields = built_schema.fields().to_vec();
+            fields.pop();
+            Arc::new(Schema::new(fields))
+        };
+        let removed_stream = Box::pin(RecordBatchStreamAdapter::new(
+            removed_schema.clone(),
+            built.map(|batch| {
+                Ok({
+                    let mut batch = batch?;
+                    batch.remove_column(batch.num_columns() - 1);
+                    batch
+                })
+            }),
+        ));
+        create_record_batch_stream_exec(removed_stream, exec_ctx.partition_id())?
+    };
+
+    // create sorted streams, build side is already sorted
+    let (left_exec, right_exec) = match broadcast_side {
+        JoinSide::Left => (
+            built_sorted,
+            create_default_ascending_sort_exec(
+                create_record_batch_stream_exec(probed, exec_ctx.partition_id())?,
+                &join_params.right_keys,
+            ),
+        ),
+        JoinSide::Right => (
+            create_default_ascending_sort_exec(
+                create_record_batch_stream_exec(probed, exec_ctx.partition_id())?,
+                &join_params.left_keys,
+            ),
+            built_sorted,
+        ),
+    };
+
+    // run sort merge join
+    let mut smj_join_params = join_params.clone();
+    smj_join_params.sort_options = vec![SortOptions::default(); join_params.left_keys.len()];
+
+    let smj_exec = Arc::new(SortMergeJoinExec::try_new_with_join_params(
+        left_exec.clone(),
+        right_exec.clone(),
+        smj_join_params,
+    )?);
+    let mut join_output = smj_exec.execute(exec_ctx.partition_id(), exec_ctx.task_ctx())?;
+
+    // send all outputs
+    while let Some(batch) = join_output.next().await.transpose()? {
+        sender.send(batch).await;
+    }
+
+    // elapsed_compute = sort time + merge time
+    let smj_time = exec_ctx.register_timer_metric("fallback_sort_merge_join_time");
+    smj_time.add_duration(Duration::from_nanos(
+        left_exec
+            .metrics()
+            .and_then(|m| m.elapsed_compute())
+            .unwrap_or(0) as u64,
+    ));
+    smj_time.add_duration(Duration::from_nanos(
+        right_exec
+            .metrics()
+            .and_then(|m| m.elapsed_compute())
+            .unwrap_or(0) as u64,
+    ));
+    smj_time.add_duration(Duration::from_nanos(
+        smj_exec
+            .metrics()
+            .and_then(|m| m.elapsed_compute())
+            .unwrap_or(0) as u64,
+    ));
+    exec_ctx
+        .baseline_metrics()
+        .elapsed_compute()
+        .add_duration(smj_time.duration());
+    Ok(())
+}
+
 async fn execute_join(
     left: Arc<dyn ExecutionPlan>,
     right: Arc<dyn ExecutionPlan>,
@@ -392,112 +491,99 @@ async fn execute_join(
         JoinSide::Right => join_params.right_keys.clone(),
     };

-    // fetch two sides asynchronously to eagerly fetch probed side
-    let (probed, map) = futures::try_join!(
-        async {
-            let probed_input = exec_ctx.stat_input(exec_ctx.execute(&probed_plan)?);
-            let probed_schema = probed_input.schema();
-            let mut probed_peeked = Box::pin(probed_input.peekable());
-            probed_peeked.as_mut().peek().await;
-            Ok(Box::pin(RecordBatchStreamAdapter::new(
-                probed_schema,
-                probed_peeked,
-            )))
-        },
-        async {
-            if is_built {
-                collect_join_hash_map(
-                    exec_ctx.clone(),
-                    cached_build_hash_map_id,
-                    built_plan,
-                    &map_keys,
-                    build_time.clone(),
-                )
-                .await
-            } else {
-                build_join_hash_map(exec_ctx.clone(), built_plan, &map_keys, build_time.clone())
-                    .await
-            }
-        }
-    )?;
-
-    exec_ctx
-        .baseline_metrics()
-        .elapsed_compute()
-        .add_duration(build_time.duration());
-
-    execute_join_with_map(
-        probed,
-        map,
-        join_params,
-        broadcast_side,
-        exec_ctx,
-        probed_side_hash_time,
-        probed_side_search_time,
-        probed_side_compare_time,
-        build_output_time,
-        sender,
+    let probed_input = exec_ctx.stat_input(exec_ctx.execute(&probed_plan)?);
+    let built_input = if is_built {
+        exec_ctx.stat_input(exec_ctx.execute(&built_plan)?)
+    } else {
+        let data_input = exec_ctx.stat_input(exec_ctx.execute(&built_plan)?);
+        let built_schema = join_hash_map_schema(&data_input.schema());
+        execute_build_hash_map(
+            data_input,
+            map_keys.clone(),
+            exec_ctx.with_new_output_schema(built_schema),
+            build_time.clone(),
+        )?
+    };
+    let built_collected = collect_join_hash_map(
+        Box::pin(built_input.peekable()),
+        cached_build_hash_map_id.filter(|_| is_built),
+        &map_keys,
+        build_time,
     )
-    .await
-}
-
-async fn build_join_hash_map(
-    exec_ctx: Arc<ExecutionContext>,
-    built_plan: Arc<dyn ExecutionPlan>,
-    key_exprs: &[PhysicalExprRef],
-    build_time: Time,
-) -> Result<Arc<JoinHashMap>> {
-    let input = exec_ctx.execute_with_input_stats(&built_plan)?;
-    let data_schema = input.schema();
-    let hash_map_schema = join_hash_map_schema(&data_schema);
-    let data_batches: Vec<RecordBatch> = input.try_collect().await?;
-
-    let join_hash_map = build_time.with_timer(|| {
-        let data_batch = concat_batches(&data_schema, data_batches.iter())?;
-        if data_batch.num_rows() == 0 {
-            return Ok(Arc::new(JoinHashMap::create_empty(
-                hash_map_schema,
-                key_exprs,
-            )?));
+    .await?;
+
+    match built_collected {
+        CollectJoinHashMapResult::Map(map) => {
+            let join_with_map = execute_join_with_map(
+                probed_input,
+                map,
+                join_params,
+                broadcast_side,
+                exec_ctx,
+                probed_side_hash_time,
+                probed_side_search_time,
+                probed_side_compare_time,
+                build_output_time,
+                sender,
+            );
+            join_with_map.await?;
+        }
+        CollectJoinHashMapResult::SortedStream(stream) => {
+            let built_input = Box::pin(RecordBatchStreamAdapter::new(
+                stream.get_ref().schema(),
+                stream,
+            ));
+            let join_with_smj_fallback = execute_join_with_smj_fallback(
+                probed_input,
+                built_input,
+                join_params,
+                broadcast_side,
+                exec_ctx,
+                sender,
+            );
+            join_with_smj_fallback.await?;
         }
+    }
+    Ok(())
+}

-        let join_hash_map = JoinHashMap::create_from_data_batch(data_batch, key_exprs)?;
-        Ok::<_, DataFusionError>(Arc::new(join_hash_map))
-    })?;
-    Ok(join_hash_map)
+enum CollectJoinHashMapResult {
+    Map(Arc<JoinHashMap>),
+    SortedStream(Pin<Box<Peekable<SendableRecordBatchStream>>>),
 }

 async fn collect_join_hash_map(
-    exec_ctx: Arc<ExecutionContext>,
+    input: Pin<Box<Peekable<SendableRecordBatchStream>>>,
     cached_build_hash_map_id: Option<String>,
-    built_plan: Arc<dyn ExecutionPlan>,
     key_exprs: &[PhysicalExprRef],
     build_time: Time,
-) -> Result<Arc<JoinHashMap>> {
+) -> Result<CollectJoinHashMapResult> {
     Ok(match cached_build_hash_map_id {
         Some(cached_id) => {
             get_cached_join_hash_map(&cached_id, || async {
-                let input = exec_ctx.execute_with_input_stats(&built_plan)?;
                 collect_join_hash_map_without_caching(input, key_exprs, build_time).await
             })
             .await?
         }
-        None => {
-            let input = exec_ctx.execute_with_input_stats(&built_plan)?;
-            let map = collect_join_hash_map_without_caching(input, key_exprs, build_time).await?;
-            Arc::new(map)
-        }
+        None => collect_join_hash_map_without_caching(input, key_exprs, build_time).await?,
     })
 }

 async fn collect_join_hash_map_without_caching(
-    input: SendableRecordBatchStream,
+    mut input: Pin<Box<Peekable<SendableRecordBatchStream>>>,
     key_exprs: &[PhysicalExprRef],
     build_time: Time,
-) -> Result<JoinHashMap> {
-    let hash_map_schema = input.schema();
-    let hash_map_batches: Vec<RecordBatch> = input.try_collect().await?;
+) -> Result<CollectJoinHashMapResult> {
+    let hash_map_schema = input.get_ref().schema();
+    let is_smj_fallback_join = matches!(
+        input.as_mut().peek().await,
+        Some(Ok(batch)) if !JoinHashMap::record_batch_contains_hash_map(batch),
+    );
+    if is_smj_fallback_join {
+        return Ok(CollectJoinHashMapResult::SortedStream(input));
+    }
+    let hash_map_batches: Vec<RecordBatch> = input.try_collect().await?;
     build_time.with_timer(|| {
         let join_hash_map = match hash_map_batches.len() {
             0 => JoinHashMap::create_empty(hash_map_schema, key_exprs)?,
@@ -510,7 +596,7 @@ async fn collect_join_hash_map_without_caching(
             }
             n => return df_execution_err!("expect zero or one hash map batch, got {n}"),
         };
-        Ok(join_hash_map)
+        Ok(CollectJoinHashMapResult::Map(Arc::new(join_hash_map)))
     })
 }

@@ -534,10 +620,10 @@
     fn num_output_rows(&self) -> usize;
 }

-async fn get_cached_join_hash_map<Fut: Future<Output = Result<JoinHashMap>> + Send>(
+async fn get_cached_join_hash_map<Fut: Future<Output = Result<CollectJoinHashMapResult>> + Send>(
     cached_id: &str,
     init: impl FnOnce() -> Fut,
-) -> Result<Arc<JoinHashMap>> {
+) -> Result<CollectJoinHashMapResult> {
     type Slot = Arc<tokio::sync::Mutex<Weak<JoinHashMap>>>;
     static CACHED_JOIN_HASH_MAP: OnceCell<Arc<Mutex<HashMap<String, Slot>>>> = OnceCell::new();

@@ -556,11 +642,17 @@ async fn get_cached_join_hash_map<Fut: Future<Output = Result<JoinHashMap>> + Se
     let mut slot = slot.lock().await;
     if let Some(cached) = slot.upgrade() {
         log::info!("got cached broadcast join hash map: ${cached_id}");
-        Ok(cached)
+        Ok(CollectJoinHashMapResult::Map(cached))
     } else {
         log::info!("collecting broadcast join hash map: ${cached_id}");
-        let new = Arc::new(init().await?);
-        *slot = Arc::downgrade(&new);
-        Ok(new)
+        match init().await?
+        {
+            CollectJoinHashMapResult::Map(map) => {
+                *slot = Arc::downgrade(&map);
+                Ok(CollectJoinHashMapResult::Map(map))
+            }
+            CollectJoinHashMapResult::SortedStream(sorted_stream) => {
+                Ok(CollectJoinHashMapResult::SortedStream(sorted_stream))
+            }
+        }
     }
 }
diff --git a/native-engine/datafusion-ext-plans/src/common/execution_context.rs b/native-engine/datafusion-ext-plans/src/common/execution_context.rs
index 90816751..1b951658 100644
--- a/native-engine/datafusion-ext-plans/src/common/execution_context.rs
+++ b/native-engine/datafusion-ext-plans/src/common/execution_context.rs
@@ -22,7 +22,7 @@ use std::{
 };

 use arrow::{array::RecordBatch, datatypes::SchemaRef};
-use blaze_jni_bridge::{conf, conf::BooleanConf, is_jni_bridge_inited, is_task_running};
+use blaze_jni_bridge::{conf, conf::BooleanConf, is_task_running};
 use datafusion::{
     common::Result,
     execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext},
@@ -53,8 +53,8 @@ pub struct ExecutionContext {
     output_schema: SchemaRef,
     metrics: ExecutionPlanMetricsSet,
     baseline_metrics: BaselineMetrics,
-    spill_metrics: OnceCell<SpillMetrics>,
-    input_stat_metrics: OnceCell<Option<InputBatchStatistics>>,
+    spill_metrics: Arc<OnceCell<SpillMetrics>>,
+    input_stat_metrics: Arc<OnceCell<Option<InputBatchStatistics>>>,
 }

 impl ExecutionContext {
@@ -70,8 +70,20 @@ impl ExecutionContext {
             output_schema,
             baseline_metrics: BaselineMetrics::new(&metrics, partition_id),
             metrics: metrics.clone(),
-            spill_metrics: OnceCell::new(),
-            input_stat_metrics: OnceCell::new(),
+            spill_metrics: Arc::default(),
+            input_stat_metrics: Arc::default(),
+        })
+    }
+
+    pub fn with_new_output_schema(&self, output_schema: SchemaRef) -> Arc<Self> {
+        Arc::new(Self {
+            task_ctx: self.task_ctx.clone(),
+            partition_id: self.partition_id,
+            output_schema,
+            metrics: self.metrics.clone(),
+            baseline_metrics: self.baseline_metrics.clone(),
+            spill_metrics: self.spill_metrics.clone(),
+            input_stat_metrics: self.input_stat_metrics.clone(),
         })
     }

@@ -326,7 +338,7 @@ impl InputBatchStatistics {
         metrics_set: &ExecutionPlanMetricsSet,
         partition: usize,
     ) -> Result<Option<Self>> {
-        let enabled = is_jni_bridge_inited() && conf::INPUT_BATCH_STATISTICS_ENABLE.value()?;
+        let enabled = conf::INPUT_BATCH_STATISTICS_ENABLE.value().unwrap_or(false);
         Ok(enabled.then_some(Self::from_metrics_set(metrics_set, partition)))
     }
diff --git a/native-engine/datafusion-ext-plans/src/common/mod.rs b/native-engine/datafusion-ext-plans/src/common/mod.rs
index 648859a2..96005755 100644
--- a/native-engine/datafusion-ext-plans/src/common/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/common/mod.rs
@@ -16,6 +16,8 @@ pub mod cached_exprs_evaluator;
 pub mod column_pruning;
 pub mod execution_context;
 pub mod ipc_compression;
+pub mod offsetted;
+pub mod stream_exec;
 pub mod timer_helper;

 pub trait SliceAsRawBytes {
diff --git a/native-engine/datafusion-ext-plans/src/common/offsetted.rs b/native-engine/datafusion-ext-plans/src/common/offsetted.rs
new file mode 100644
index 00000000..7e49dedd
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/common/offsetted.rs
@@ -0,0 +1,192 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{marker::PhantomData, ops::Range};
+
+use datafusion::common::Result;
+use datafusion_ext_commons::algorithm::rdx_queue::{KeyForRadixQueue, RadixQueue};
+use num::PrimInt;
+
+pub struct Offsetted<O, T> {
+    offsets: Vec<O>,
+    data: T,
+}
+
+impl<O: PrimInt, T> Offsetted<O, T> {
+    pub fn new(offsets: Vec<O>, data: T) -> Self {
+        Self { offsets, data }
+    }
+
+    pub fn offsets(&self) -> &[O] {
+        &self.offsets
+    }
+
+    pub fn offset(&self, i: usize) -> Range<O> {
+        self.offsets[i]..self.offsets[i + 1]
+    }
+
+    pub fn data(&self) -> &T {
+        &self.data
+    }
+
+    pub fn data_mut(&mut self) -> &mut T {
+        &mut self.data
+    }
+
+    pub fn map_data<U>(self, f: impl FnOnce(T) -> U) -> Offsetted<O, U> {
+        Offsetted {
+            offsets: self.offsets,
+            data: f(self.data),
+        }
+    }
+
+    pub fn try_map_data<U>(self, f: impl FnOnce(T) -> Result<U>) -> Result<Offsetted<O, U>> {
+        Ok(Offsetted {
+            offsets: self.offsets,
+            data: f(self.data)?,
+        })
+    }
+}
+
+pub struct OffsettedCursor<O, T> {
+    offsetted: Offsetted<O, T>,
+    cur: usize,
+}
+
+impl<O: PrimInt, T> KeyForRadixQueue for OffsettedCursor<O, T> {
+    fn rdx(&self) -> usize {
+        self.cur
+    }
+}
+
+impl<O: PrimInt, T> OffsettedCursor<O, T> {
+    pub fn new(offsetted: Offsetted<O, T>) -> Self {
+        let mut new = Self { offsetted, cur: 0 };
+        new.skip_empty_partitions();
+        new
+    }
+
+    pub fn skip_empty_partitions(&mut self) {
+        let offsets = self.offsetted.offsets();
+        while self.cur + 1 < offsets.len() && offsets[self.cur + 1] == offsets[self.cur] {
+            self.cur += 1;
+        }
+    }
+}
+
+pub struct OffsettedMergeIterator<'a, O, T> {
+    num_partitions: usize,
+    cursors: RadixQueue<OffsettedCursor<O, T>>,
+    cur_partition_id: usize,
+    cur_offset: O,
+    merged_offsets: Vec<O>,
+    _phantom: PhantomData<&'a ()>,
+}
+
+impl<'a, O: PrimInt + 'a, T: 'a> OffsettedMergeIterator<'a, O, T> {
+    pub fn new(num_partitions: usize, offsets: Vec<Offsetted<O, T>>) -> Self {
+        assert!(!offsets.is_empty(), "OffsettedMergeIterator got no offsets");
+
+        let cursors = RadixQueue::new(
+            offsets
+                .into_iter()
+                .map(|offsetted| OffsettedCursor::new(offsetted))
+                .collect(),
+            num_partitions,
+        );
+        let cur_partition_id = cursors.peek().cur;
+
+        Self {
+            num_partitions,
+            cursors,
+            cur_partition_id,
+            cur_offset: O::zero(),
+            merged_offsets: Default::default(),
+            _phantom: Default::default(),
+        }
+    }
+
+    pub fn peek_next_partition_id(&self) -> usize {
+        self.cursors.peek().cur
+    }
+
+    pub fn merged_offsets(&self) -> &[O] {
+        &self.merged_offsets
+    }
+
+    pub fn next_partition_chunk<'z>(
+        &'z mut self,
+    ) -> Option<(usize, OffsettedMergePartitionChunkIterator<'a, 'z, O, T>)> {
+        let chunk_partition_id = self.peek_next_partition_id();
+        if chunk_partition_id < self.num_partitions {
+            let chunk_iter = OffsettedMergePartitionChunkIterator {
+                merge_iter: self,
+                chunk_partition_id,
+            };
+            return Some((chunk_partition_id, chunk_iter));
+        }
+        None
+    }
+}
+
+impl<'a, O: PrimInt + 'a, T: 'a> Iterator for OffsettedMergeIterator<'a, O, T> {
+    type Item = (usize, &'a mut T, Range<O>);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let mut min_cursor = self.cursors.peek_mut();
+        self.cur_partition_id = min_cursor.cur;
+        self.merged_offsets
+            .resize(self.cur_partition_id + 1, self.cur_offset);
+
+        if min_cursor.cur >= self.num_partitions {
+            return None; // no more partitions
+        }
+
+        let range = min_cursor.offsetted.offset(self.cur_partition_id);
+        let data = unsafe {
+            // safety: bypass lifetime checker
+            std::mem::transmute(min_cursor.offsetted.data_mut())
+        };
+
+        // forward partition id in min_cursor
+        self.cur_offset = self.cur_offset + range.end - range.start;
+        min_cursor.cur += 1;
+        min_cursor.skip_empty_partitions();
+
+        // return current reader
+        Some((self.cur_partition_id, data, range))
+    }
+}
+
+pub struct OffsettedMergePartitionChunkIterator<'a, 'z, O, T> {
+    merge_iter: &'z mut OffsettedMergeIterator<'a, O, T>,
+    chunk_partition_id: usize,
+}
+
+impl<'a, O: PrimInt + 'a, T: 'a> Iterator for OffsettedMergePartitionChunkIterator<'a, '_, O, T> {
+    type Item = (&'a mut T, Range<O>);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.merge_iter.peek_next_partition_id() == self.chunk_partition_id {
+            return self.merge_iter.next().map(|(_, data, range)| (data, range));
+        }
+        None
+    }
+}
+
+pub type OffsettedMergePartitionChunkIteratorBypassLifetimeCheck<O, T> =
+    OffsettedMergePartitionChunkIterator<'static, 'static, O, T>;
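Offsetted pairs a payload with a partition-offset table (offsets[i]..offsets[i + 1] is
partition i's slice), and OffsettedMergeIterator yields all chunks of partition 0 across
every input before moving on to partition 1, and so on. A minimal sketch under those
semantics (illustrative values only; the payload is just a label standing in for a spill):

    let a = Offsetted::new(vec![0u64, 2, 2, 5], "spill-a"); // rows in partitions 0 and 2
    let b = Offsetted::new(vec![0u64, 0, 4, 4], "spill-b"); // rows in partition 1 only
    let mut merged = OffsettedMergeIterator::new(3, vec![a, b]);
    while let Some((partition_id, data, range)) = merged.next() {
        // yields (0, "spill-a", 0..2), then (1, "spill-b", 0..4), then (2, "spill-a", 2..5)
        println!("partition {partition_id}: {data} bytes {range:?}");
    }
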
diff --git a/native-engine/datafusion-ext-plans/src/common/stream_exec.rs b/native-engine/datafusion-ext-plans/src/common/stream_exec.rs
new file mode 100644
index 00000000..d7e4740e
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/common/stream_exec.rs
@@ -0,0 +1,72 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow_schema::SchemaRef;
+use datafusion::{
+    common::Result,
+    execution::TaskContext,
+    physical_plan::{
+        streaming::{PartitionStream, StreamingTableExec},
+        EmptyRecordBatchStream, ExecutionPlan, SendableRecordBatchStream,
+    },
+};
+use parking_lot::Mutex;
+
+// wrap a record batch stream into a datafusion execution plan
+pub fn create_record_batch_stream_exec(
+    stream: SendableRecordBatchStream,
+    partition_id: usize,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let schema = stream.schema();
+    let empty_partition_stream: Arc<dyn PartitionStream> = Arc::new(SinglePartitionStream::new(
+        Box::pin(EmptyRecordBatchStream::new(schema.clone())),
+    ));
+    let mut streams: Vec<Arc<dyn PartitionStream>> = (0..=partition_id)
+        .map(|_| empty_partition_stream.clone())
+        .collect();
+    streams[partition_id] = Arc::new(SinglePartitionStream::new(stream));
+
+    Ok(Arc::new(StreamingTableExec::try_new(
+        schema,
+        streams,
+        None,
+        vec![],
+        false,
+        None,
+    )?))
+}
+
+struct SinglePartitionStream(SchemaRef, Arc<Mutex<SendableRecordBatchStream>>);
+
+impl SinglePartitionStream {
+    fn new(stream: SendableRecordBatchStream) -> Self {
+        Self(stream.schema(), Arc::new(Mutex::new(stream)))
+    }
+}
+
+impl PartitionStream for SinglePartitionStream {
+    fn schema(&self) -> &SchemaRef {
+        &self.0
+    }
+
+    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+        let mut stream = self.1.lock();
+        std::mem::replace(
+            &mut *stream,
+            Box::pin(EmptyRecordBatchStream::new(self.0.clone())),
+        )
+    }
+}
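create_record_batch_stream_exec is what lets the fallback feed an already-running stream
into operators that only accept plans (SortExec, SortMergeJoinExec). Note the one-shot
semantics: SinglePartitionStream hands out the real stream on the first execute() and an
empty stream afterwards, and partitions below partition_id are padded with empties so the
partition numbering is preserved. Usage sketch (illustrative; input_stream, key_exprs and
task_ctx are assumed bindings):

    let plan = create_record_batch_stream_exec(input_stream, partition_id)?;
    let sorted = create_default_ascending_sort_exec(plan, &key_exprs);
    let mut stream = sorted.execute(partition_id, task_ctx)?;
    // a second sorted.execute(partition_id, ..) would only yield an empty stream
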
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
index 3a9e8dc9..10d10ab6 100644
--- a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
@@ -53,7 +53,11 @@ pub struct JoinerParams {
 }

 impl JoinerParams {
-    const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self {
+    pub const fn new(
+        probe_side: ProbeSide,
+        probe_side_outer: bool,
+        build_side_outer: bool,
+    ) -> Self {
         Self {
             probe_side,
             probe_side_outer,
@@ -62,15 +66,15 @@ impl JoinerParams {
     }
 }

-const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false);
-const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false);
-const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true);
-const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true);
+pub const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false);
+pub const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false);
+pub const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true);
+pub const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true);

-const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false);
-const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true);
-const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false);
-const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true);
+pub const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false);
+pub const RIGHT_PROBED_LEFT: JoinerParams = JoinerParams::new(R, false, true);
+pub const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false);
+pub const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true);

 pub type LProbedInnerJoiner = FullJoiner<LEFT_PROBED_INNER>;
 pub type LProbedLeftJoiner = FullJoiner<LEFT_PROBED_LEFT>;
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
index dcfaacbc..759e3b42 100644
--- a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
@@ -15,7 +15,7 @@ use std::{
     fmt::{Debug, Formatter},
     hash::{BuildHasher, Hasher},
-    io::Cursor,
+    io::{Cursor, Read, Write},
     mem::MaybeUninit,
     simd::{cmp::SimdPartialEq, Simd},
     sync::Arc,
@@ -100,8 +100,20 @@ impl Table {
             num_rows < 1073741824,
             "join hash table: number of rows exceeded 2^30: {num_rows}"
         );
-
         let hashes = join_create_hashes(num_rows, key_columns);
+        Self::create_from_key_columns_and_hashes(num_rows, key_columns, hashes)
+    }
+
+    fn create_from_key_columns_and_hashes(
+        num_rows: usize,
+        key_columns: &[ArrayRef],
+        hashes: Vec<u32>,
+    ) -> Result<Self> {
+        assert!(
+            num_rows < 1073741824,
+            "join hash table: number of rows exceeded 2^30: {num_rows}"
+        );
+
         let key_is_valid = |row_idx| key_columns.iter().all(|col| col.is_valid(row_idx));
         let mut mapped_indices = unchecked!(vec![]);
         let mut num_valid_items = 0;
@@ -179,12 +191,10 @@ impl Table {
         })
     }

-    pub fn load_from_raw_bytes(raw_bytes: &[u8]) -> Result<Self> {
-        let mut cursor = Cursor::new(raw_bytes);
-
+    pub fn read_from(mut r: impl Read) -> Result<Self> {
         // read map
-        let num_valid_items = read_len(&mut cursor)?;
-        let map_mod_bits = read_len(&mut cursor)? as u32;
+        let num_valid_items = read_len(&mut r)?;
+        let map_mod_bits = read_len(&mut r)? as u32;
         let mut map = vec![
             unsafe {
                 // safety: no need to init to zeros
@@ -193,13 +203,13 @@ impl Table {
             };
             1usize << map_mod_bits
         ];
-        read_raw_slice(&mut map, &mut cursor)?;
+        read_raw_slice(&mut map, &mut r)?;

         // read mapped indices
-        let mapped_indices_len = read_len(&mut cursor)?;
+        let mapped_indices_len = read_len(&mut r)?;
         let mut mapped_indices = Vec::with_capacity(mapped_indices_len);
         for _ in 0..mapped_indices_len {
-            mapped_indices.push(read_len(&mut cursor)? as u32);
+            mapped_indices.push(read_len(&mut r)? as u32);
         }

         Ok(Self {
@@ -210,38 +220,18 @@ impl Table {
         })
     }

-    pub fn try_into_raw_bytes(self) -> Result<Vec<u8>> {
-        let mut raw_bytes = Vec::with_capacity(
-            (8 + self.mapped_indices.len() + size_of::<MapValueGroup>())
-                + (24 + self.map.len() * size_of::<MapValueGroup>()),
-        );
-
+    pub fn write_to(self, mut w: impl Write) -> Result<()> {
         // write map
-        write_len(self.num_valid_items, &mut raw_bytes)?;
-        write_len(self.map_mod_bits as usize, &mut raw_bytes)?;
-        write_raw_slice(&self.map, &mut raw_bytes)?;
+        write_len(self.num_valid_items, &mut w)?;
+        write_len(self.map_mod_bits as usize, &mut w)?;
+        write_raw_slice(&self.map, &mut w)?;

         // write mapped indices
-        write_len(self.mapped_indices.len(), &mut raw_bytes)?;
+        write_len(self.mapped_indices.len(), &mut w)?;
         for &v in self.mapped_indices.as_slice() {
-            write_len(v as usize, &mut raw_bytes)?;
-        }
-
-        raw_bytes.shrink_to_fit();
-        Ok(raw_bytes)
-    }
-
-    pub fn lookup(&self, hash: u32) -> MapValue {
-        let mut i = (hash % (1 << self.map_mod_bits)) as usize;
-        loop {
-            let hash_matched = self.map[i].hashes.simd_eq(Simd::splat(hash));
-            let empty = self.map[i].hashes.simd_eq(Simd::splat(0));
-
-            if let Some(pos) = (hash_matched | empty).first_set() {
-                return self.map[i].values[pos];
-            }
-            i += 1;
+            write_len(v as usize, &mut w)?;
         }
+        Ok(())
     }

     pub fn lookup_many(&self, hashes: Vec<u32>) -> Vec<MapValue> {
@@ -329,22 +319,40 @@ impl JoinHashMap {
         })
     }

+    pub fn create_from_data_batch_and_hashes(
+        data_batch: RecordBatch,
+        key_columns: Vec<ArrayRef>,
+        hashes: Vec<u32>,
+    ) -> Result<Self> {
+        let table =
+            Table::create_from_key_columns_and_hashes(data_batch.num_rows(), &key_columns, hashes)?;
+
+        Ok(Self {
+            data_batch,
+            key_columns,
+            table,
+        })
+    }
+
     pub fn create_empty(hash_map_schema: SchemaRef, key_exprs: &[PhysicalExprRef]) -> Result<Self> {
         let data_batch = RecordBatch::new_empty(hash_map_schema);
         Self::create_from_data_batch(data_batch, key_exprs)
     }

+    pub fn record_batch_contains_hash_map(batch: &RecordBatch) -> bool {
+        let table_data_column = batch.column(batch.num_columns() - 1);
+        table_data_column.is_valid(0)
+    }
+
     pub fn load_from_hash_map_batch(
         hash_map_batch: RecordBatch,
         key_exprs: &[PhysicalExprRef],
     ) -> Result<Self> {
         let mut data_batch = hash_map_batch.clone();
-        let table = Table::load_from_raw_bytes(
-            data_batch
-                .remove_column(data_batch.num_columns() - 1)
-                .as_binary::<i32>()
-                .value(0),
-        )?;
+
+        let table_data_column = data_batch.remove_column(data_batch.num_columns() - 1);
+        let mut table_data = Cursor::new(table_data_column.as_binary::<i32>().value(0));
+        let table = Table::read_from(&mut table_data)?;
+
         let key_columns: Vec<ArrayRef> = key_exprs
             .iter()
             .map(|expr| {
@@ -365,12 +373,17 @@ impl JoinHashMap {
         if self.data_batch.num_rows() == 0 {
             return Ok(RecordBatch::new_empty(schema));
         }
+
         let mut table_col_builder = BinaryBuilder::new();
-        table_col_builder.append_value(&self.table.try_into_raw_bytes()?);
+        let mut table_data = vec![];
+        self.table.write_to(&mut table_data)?;
+        table_col_builder.append_value(&table_data);
+
         for _ in 1..self.data_batch.num_rows() {
             table_col_builder.append_null();
         }
         let table_col: ArrayRef = Arc::new(table_col_builder.finish());
+
         Ok(RecordBatch::try_new(
             schema,
             vec![self.data_batch.columns().to_vec(), vec![table_col]].concat(),
@@ -397,10 +410,6 @@ impl JoinHashMap {
         self.data_batch.num_rows() == 0
     }

-    pub fn lookup(&self, hash: u32) -> MapValue {
-        self.table.lookup(hash)
-    }
-
     pub fn lookup_many(&self, hashes: Vec<u32>) -> Vec<MapValue> {
         self.table.lookup_many(hashes)
     }
@@ -454,7 +463,7 @@ pub fn join_create_hashes(num_rows: usize, key_columns: &[ArrayRef]) -> Vec<u32>
 }

 #[inline]
-fn join_table_field() -> FieldRef {
+pub fn join_table_field() -> FieldRef {
     static BHJ_KEY_FIELD: OnceCell<FieldRef> = OnceCell::new();
     BHJ_KEY_FIELD
         .get_or_init(|| Arc::new(Field::new("~TABLE", DataType::Binary, true)))
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs
index 074856e9..aad6f827 100644
--- a/native-engine/datafusion-ext-plans/src/joins/test.rs
+++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -37,6 +37,7 @@ mod tests {
         broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec,
         broadcast_join_exec::BroadcastJoinExec,
         joins::join_utils::{JoinType, JoinType::*},
+        memmgr::MemManager,
         sort_merge_join_exec::SortMergeJoinExec,
     };

@@ -186,6 +187,7 @@ mod tests {
         on: JoinOn,
         join_type: JoinType,
    ) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
+        MemManager::init(1000000);
         let session_ctx = SessionContext::new();
         let task_ctx = session_ctx.task_ctx();
         let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?;
diff --git a/native-engine/datafusion-ext-plans/src/memmgr/spill.rs b/native-engine/datafusion-ext-plans/src/memmgr/spill.rs
index 7c377529..ccf0461d 100644
--- a/native-engine/datafusion-ext-plans/src/memmgr/spill.rs
+++ b/native-engine/datafusion-ext-plans/src/memmgr/spill.rs
@@ -299,3 +299,30 @@ impl Write for IoTimeWriteWrapper {
         self.0.flush()
     }
 }
+
+pub struct OwnedSpillBufReader<'a> {
+    spill: Box<dyn Spill>,
+    buf_reader: BufReader<Box<dyn Read + Send + 'a>>,
+}
+
+impl<'a> OwnedSpillBufReader<'a> {
+    pub fn from(spill: Box<dyn Spill>) -> Self {
+        let buf_reader = unsafe {
+            // safety: bypass ownership and lifetime checker
+            std::mem::transmute(spill.get_buf_reader())
+        };
+        Self { spill, buf_reader }
+    }
+
+    pub fn spill(&self) -> &Box<dyn Spill> {
+        &self.spill
+    }
+
+    pub fn spill_mut(&mut self) -> &mut Box<dyn Spill> {
+        &mut self.spill
+    }
+
+    pub fn buf_reader(&mut self) -> &mut BufReader<Box<dyn Read + Send + 'a>> {
+        &mut self.buf_reader
+    }
+}
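OwnedSpillBufReader exists to tie a spill and its reader together in one owned value (via
a transmute that erases the borrow), so spills can be moved into the merge iterator below
while still being read. A sketch of the consumption pattern used by the repartitioner
further down (illustrative; spill, range and output are assumed bindings):

    use std::io::Read;

    let mut owned = OwnedSpillBufReader::from(spill);
    // read exactly one partition's byte range out of the spill
    let mut part = owned.buf_reader().take(range.end - range.start);
    std::io::copy(&mut part, &mut output)?;
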
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs b/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
index ba4ef51f..4a29530e 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/buffered_data.rs
@@ -20,22 +20,24 @@ use bytesize::ByteSize;
 use count_write::CountWrite;
 use datafusion::{common::Result, physical_plan::metrics::Time};
 use datafusion_ext_commons::{
-    algorithm::{
-        rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree},
-        rdxsort::radix_sort_by_key,
-    },
+    algorithm::rdx_sort::radix_sort_by_key,
     arrow::{
         array_size::ArraySize,
         selection::{create_batch_interleaver, BatchInterleaver},
     },
     compute_suggested_batch_size_for_output, df_execution_err,
 };
+use itertools::Itertools;
 use jni::objects::GlobalRef;
 #[cfg(test)]
 use parking_lot::Mutex;

 use crate::{
-    common::{ipc_compression::IpcCompressionWriter, timer_helper::TimerHelper},
+    common::{
+        ipc_compression::IpcCompressionWriter,
+        offsetted::{Offsetted, OffsettedMergeIterator},
+        timer_helper::TimerHelper,
+    },
     shuffle::{
         evaluate_hashes, evaluate_partition_ids, evaluate_range_partition_ids,
         evaluate_robin_partition_ids, rss::RssWriter, Partitioning,
@@ -132,32 +134,22 @@ impl BufferedData {
         let num_partitions = self.partitioning.partition_count();
         let mut writer = IpcCompressionWriter::new(CountWrite::from(&mut w));
         let mut offsets = vec![];
-        let mut offset = 0;
         let mut iter = self.into_sorted_batches()?;

-        while !iter.finished() {
+        while let Some((partition_id, batch_iter)) = iter.next_partition_chunk() {
             if !is_task_running() {
                 df_execution_err!("task completed/killed")?;
             }
-            let cur_part_id = iter.cur_part_id();
-            while offsets.len() <= cur_part_id as usize {
-                offsets.push(offset); // fill offsets of empty partitions
-            }

-            // write all batches with this part id
-            while iter.cur_part_id() == cur_part_id {
-                let batch = iter.next_batch()?;
+            offsets.resize(partition_id + 1, writer.inner().count());
+            for batch in batch_iter {
                 writer.write_batch(batch.num_rows(), batch.columns())?;
             }
             writer.finish_current_buf()?;
-            offset = writer.inner().count();
-            offsets.push(offset);
-        }
-        while offsets.len() <= num_partitions {
-            offsets.push(offset); // fill offsets of empty partitions
         }
+        offsets.resize(num_partitions + 1, writer.inner().count());

-        let compressed_size = ByteSize(offsets.last().cloned().unwrap_or_default() as u64);
+        let compressed_size = ByteSize(offsets.last().cloned().unwrap_or_default());
         log::info!("all buffered data drained, compressed_size={compressed_size}");
         Ok(offsets)
     }
@@ -177,19 +169,14 @@ impl BufferedData {
         let mut iter = self.into_sorted_batches()?;
         let mut writer = IpcCompressionWriter::new(RssWriter::new(rss_partition_writer.clone(), 0));

-        while !iter.finished() {
+        while let Some((partition_id, batch_iter)) = iter.next_partition_chunk() {
             if !is_task_running() {
                 df_execution_err!("task completed/killed")?;
             }
-            let cur_part_id = iter.cur_part_id();
-            writer.set_output(RssWriter::new(
-                rss_partition_writer.clone(),
-                cur_part_id as usize,
-            ));

             // write all batches with this part id
-            while iter.cur_part_id() == cur_part_id {
-                let batch = iter.next_batch()?;
+            writer.set_output(RssWriter::new(rss_partition_writer.clone(), partition_id));
+            for batch in batch_iter {
                 writer.write_batch(batch.num_rows(), batch.columns())?;
             }
             writer.finish_current_buf()?;
@@ -199,31 +186,16 @@ impl BufferedData {
         Ok(())
     }

-    fn into_sorted_batches(self) -> Result<PartitionedBatchesIterator> {
-        let sub_batch_size =
-            compute_suggested_batch_size_for_output(self.mem_used(), self.num_rows);
-        Ok(PartitionedBatchesIterator {
-            batch_interleaver: create_batch_interleaver(&self.sorted_batches, true)?,
-            cursors: RadixTournamentTree::new(
-                self.sorted_offsets
-                    .into_iter()
-                    .enumerate()
-                    .map(|(idx, offsets)| {
-                        let mut cur = PartCursor {
-                            idx,
-                            offsets,
-                            parts_idx: 0,
-                        };
-                        cur.skip_empty_parts();
-                        cur
-                    })
-                    .collect(),
-                self.partitioning.partition_count(),
-            ),
-            num_output_rows: 0,
-            num_rows: self.num_rows,
-            batch_size: sub_batch_size,
-        })
+    fn into_sorted_batches(self) -> Result<PartitionedBatchesIterator<'static>> {
+        let num_rows = self.num_rows;
+        let sub_batch_size = compute_suggested_batch_size_for_output(self.mem_used(), num_rows);
+        let num_partitions = self.partitioning.partition_count();
+        PartitionedBatchesIterator::try_new(
+            self.sorted_batches,
+            self.sorted_offsets,
+            sub_batch_size,
+            num_partitions,
+        )
     }

     pub fn mem_used(&self) -> usize {
@@ -235,76 +207,68 @@ impl BufferedData {
     }
 }

-struct PartitionedBatchesIterator {
+struct PartitionedBatchesIterator<'a> {
     batch_interleaver: BatchInterleaver,
-    cursors: RadixTournamentTree<PartCursor>,
-    num_output_rows: usize,
-    num_rows: usize,
+    merge_iter: OffsettedMergeIterator<'a, u32, usize>,
     batch_size: usize,
+    last_chunk_partition_id: Option<usize>,
 }

-impl PartitionedBatchesIterator {
-    pub fn cur_part_id(&self) -> u32 {
-        self.cursors.peek().rdx() as u32
+impl<'a> PartitionedBatchesIterator<'a> {
+    pub fn try_new(
+        batches: Vec<RecordBatch>,
+        batch_offsets: Vec<Vec<u32>>,
+        sub_batch_size: usize,
+        num_partitions: usize,
+    ) -> Result<Self> {
+        Ok(Self {
+            batch_interleaver: create_batch_interleaver(&batches, true)?,
+            merge_iter: OffsettedMergeIterator::new(
+                num_partitions,
+                batch_offsets
+                    .into_iter()
+                    .enumerate()
+                    .map(|(idx, offsets)| Offsetted::new(offsets, idx))
+                    .collect(),
+            ),
+            batch_size: sub_batch_size,
+            last_chunk_partition_id: None,
+        })
     }

-    pub fn finished(&self) -> bool {
-        self.num_output_rows >= self.num_rows
-    }
+    /// all iterators returned should have been fully consumed
+    pub fn next_partition_chunk(
+        &mut self,
+    ) -> Option<(usize, impl Iterator<Item = RecordBatch> + 'a)> {
+        // safety: bypass lifetime checker
+        let batches_iter =
+            unsafe { std::mem::transmute::<_, &mut PartitionedBatchesIterator<'a>>(self) };

-    pub fn next_batch(&mut self) -> Result<RecordBatch> {
-        let cur_batch_size = self.batch_size.min(self.num_rows - self.num_output_rows);
-        let cur_part_id = self.cur_part_id();
-        let mut indices = Vec::with_capacity(cur_batch_size);
+        let (chunk_partition_id, chunk) = batches_iter.merge_iter.next_partition_chunk()?;

-        // add rows with same partition id under this cursor
-        while indices.len() < cur_batch_size {
-            let mut min_cursor = self.cursors.peek_mut();
-            if min_cursor.rdx() as u32 != cur_part_id {
-                break;
-            }
-            let batch_idx = min_cursor.idx;
-            let min_offsets = &min_cursor.offsets;
-            let min_parts_idx = min_cursor.parts_idx;
-            let cur_offset_range = min_offsets[min_parts_idx]..min_offsets[min_parts_idx + 1];
-            indices.extend(cur_offset_range.map(|offset| (batch_idx, offset as usize)));
-
-            // forward to next non-empty partition
-            min_cursor.parts_idx += 1;
-            min_cursor.skip_empty_parts();
+        // last chunk must be fully consumed
+        if batches_iter.last_chunk_partition_id == Some(chunk_partition_id) {
+            panic!("last chunk not fully consumed");
         }
-
-        let batch_interleaver = &mut self.batch_interleaver;
-        let output_batch = batch_interleaver(&indices)?;
-        self.num_output_rows += output_batch.num_rows();
-        Ok(output_batch)
-    }
-}
-
-struct PartCursor {
-    idx: usize,
-    offsets: Vec<u32>,
-    parts_idx: usize,
-}
-
-impl PartCursor {
-    fn skip_empty_parts(&mut self) {
-        if self.parts_idx < self.num_partitions() {
-            if self.offsets[self.parts_idx + 1] == self.offsets[self.parts_idx] {
-                self.parts_idx += 1;
-                self.skip_empty_parts();
+        batches_iter.last_chunk_partition_id = Some(chunk_partition_id);
+
+        let batch_iter = chunk.batching(|chunk| {
+            let mut indices = vec![];
+            for (batch_idx, range) in chunk {
+                indices.extend(range.map(|offset| (*batch_idx, offset as usize)));
+                if indices.len() >= batches_iter.batch_size {
+                    break;
+                }
             }
-        }
-    }
-    fn num_partitions(&self) -> usize {
-        self.offsets.len() - 1
-    }
-}
-
-impl KeyForRadixTournamentTree for PartCursor {
-    fn rdx(&self) -> usize {
-        self.parts_idx
+            if indices.is_empty() {
+                return None;
+            }
+            let batch_interleaver = &mut batches_iter.batch_interleaver;
+            let output_batch = batch_interleaver(&indices).expect("error interleaving batches");
+            Some(output_batch)
+        });
+        Some((chunk_partition_id, batch_iter))
     }
 }

@@ -323,22 +287,16 @@ fn sort_batches_by_partition_id(
         .iter()
         .enumerate()
         .flat_map(|(batch_idx, batch)| {
-            let mut part_ids: Vec<u32> = Vec::new();
-            match partitioning {
+            let part_ids = match partitioning {
                 Partitioning::HashPartitioning(..) => {
                     // compute partition indices
                     let hashes = evaluate_hashes(partitioning, &batch)
                         .expect(&format!("error evaluating hashes with {partitioning}"));
-                    part_ids = evaluate_partition_ids(hashes, partitioning.partition_count());
+                    evaluate_partition_ids(hashes, partitioning.partition_count())
                 }
                 Partitioning::RoundRobinPartitioning(..) => {
-                    part_ids =
-                        evaluate_robin_partition_ids(partitioning, &batch, round_robin_start_rows);
+                    let part_ids = evaluate_robin_partition_ids(
+                        partitioning,
+                        &batch,
+                        round_robin_start_rows,
+                    );
                     round_robin_start_rows += batch.num_rows();
                     round_robin_start_rows %= partitioning.partition_count();
+                    part_ids
                 }
                 Partitioning::RangePartitioning(sort_expr, _, bounds) => {
-                    part_ids = evaluate_range_partition_ids(&batch, sort_expr, bounds).unwrap();
+                    evaluate_range_partition_ids(&batch, sort_expr, bounds).unwrap()
                 }
                 _ => unreachable!("unsupported partitioning: {:?}", partitioning),
             };
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/mod.rs b/native-engine/datafusion-ext-plans/src/shuffle/mod.rs
index 435bca59..31ed7503 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/mod.rs
@@ -38,12 +38,12 @@ use datafusion_ext_commons::{arrow::array_size::ArraySize, spark_hash::create_mu
 use futures::StreamExt;
 use parking_lot::Mutex as SyncMutex;

-use crate::{common::execution_context::ExecutionContext, memmgr::spill::Spill};
+use crate::common::execution_context::ExecutionContext;

 pub mod single_repartitioner;
 pub mod sort_repartitioner;

-mod buffered_data;
+pub mod buffered_data;
 mod rss;
 pub mod rss_single_repartitioner;
 pub mod rss_sort_repartitioner;
@@ -104,11 +104,6 @@ impl dyn ShuffleRepartitioner {
     }
 }

-struct ShuffleSpill {
-    spill: Box<dyn Spill>,
-    offsets: Vec<u64>,
-}
-
 #[derive(Debug, Clone)]
 pub enum Partitioning {
     /// Allocate batches using a round-robin algorithm and the specified number
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
index e3bfb70e..88a4a0da 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
@@ -14,7 +14,7 @@

 use std::{
     fs::OpenOptions,
-    io::{BufReader, Read, Seek, Write},
+    io::{Read, Write},
     sync::{Arc, Weak},
 };

@@ -25,20 +25,20 @@ use datafusion::{
     common::{DataFusionError, Result},
     physical_plan::metrics::Time,
 };
-use datafusion_ext_commons::{
-    algorithm::rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree},
-    arrow::array_size::ArraySize,
-    df_execution_err,
-};
+use datafusion_ext_commons::{arrow::array_size::ArraySize, df_execution_err};
 use futures::lock::Mutex;

 use crate::{
-    common::{execution_context::ExecutionContext, timer_helper::TimerHelper},
+    common::{
+        execution_context::ExecutionContext,
+        offsetted::{Offsetted, OffsettedMergeIterator},
+        timer_helper::TimerHelper,
+    },
     memmgr::{
-        spill::{try_new_spill, Spill},
+        spill::{try_new_spill, OwnedSpillBufReader, Spill},
         MemConsumer, MemConsumerInfo, MemManager,
     },
-    shuffle::{buffered_data::BufferedData, Partitioning, ShuffleRepartitioner, ShuffleSpill},
+    shuffle::{buffered_data::BufferedData, Partitioning, ShuffleRepartitioner},
 };

 pub struct SortShuffleRepartitioner {
@@ -48,7 +48,7 @@ pub struct SortShuffleRepartitioner {
     output_data_file: String,
     output_index_file: String,
     data: Mutex<BufferedData>,
-    spills: Mutex<Vec<ShuffleSpill>>,
+    spills: Mutex<Vec<Offsetted<u64, Box<dyn Spill>>>>,
     num_output_partitions: usize,
     output_io_time: Time,
 }
@@ -100,7 +100,7 @@ impl MemConsumer for SortShuffleRepartitioner {
         let spill = tokio::task::spawn_blocking(move || {
             let mut spill = try_new_spill(&spill_metrics)?;
             let offsets = data.write(spill.get_buf_writer())?;
-            Ok::<_, DataFusionError>(ShuffleSpill { spill, offsets })
+            Ok::<_, DataFusionError>(Offsetted::new(offsets, spill))
diff --git a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
index e3bfb70e..88a4a0da 100644
--- a/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
+++ b/native-engine/datafusion-ext-plans/src/shuffle/sort_repartitioner.rs
@@ -14,7 +14,7 @@
 
 use std::{
     fs::OpenOptions,
-    io::{BufReader, Read, Seek, Write},
+    io::{Read, Write},
     sync::{Arc, Weak},
 };
 
@@ -25,20 +25,20 @@ use datafusion::{
     common::{DataFusionError, Result},
     physical_plan::metrics::Time,
 };
-use datafusion_ext_commons::{
-    algorithm::rdx_tournament_tree::{KeyForRadixTournamentTree, RadixTournamentTree},
-    arrow::array_size::ArraySize,
-    df_execution_err,
-};
+use datafusion_ext_commons::{arrow::array_size::ArraySize, df_execution_err};
 use futures::lock::Mutex;
 
 use crate::{
-    common::{execution_context::ExecutionContext, timer_helper::TimerHelper},
+    common::{
+        execution_context::ExecutionContext,
+        offsetted::{Offsetted, OffsettedMergeIterator},
+        timer_helper::TimerHelper,
+    },
     memmgr::{
-        spill::{try_new_spill, Spill},
+        spill::{try_new_spill, OwnedSpillBufReader, Spill},
         MemConsumer, MemConsumerInfo, MemManager,
     },
-    shuffle::{buffered_data::BufferedData, Partitioning, ShuffleRepartitioner, ShuffleSpill},
+    shuffle::{buffered_data::BufferedData, Partitioning, ShuffleRepartitioner},
 };
 
 pub struct SortShuffleRepartitioner {
@@ -48,7 +48,7 @@ pub struct SortShuffleRepartitioner {
     output_data_file: String,
     output_index_file: String,
     data: Mutex<BufferedData>,
-    spills: Mutex<Vec<ShuffleSpill>>,
+    spills: Mutex<Vec<Offsetted<Box<dyn Spill>>>>,
     num_output_partitions: usize,
     output_io_time: Time,
 }
@@ -100,7 +100,7 @@ impl MemConsumer for SortShuffleRepartitioner {
         let spill = tokio::task::spawn_blocking(move || {
             let mut spill = try_new_spill(&spill_metrics)?;
             let offsets = data.write(spill.get_buf_writer())?;
-            Ok::<_, DataFusionError>(ShuffleSpill { spill, offsets })
+            Ok::<_, DataFusionError>(Offsetted::new(offsets, spill))
         })
         .await
         .expect("tokio spawn_blocking error")?;
@@ -199,40 +199,17 @@ impl ShuffleRepartitioner for SortShuffleRepartitioner {
             return Ok(());
         }
 
-        struct SpillCursor<'a> {
-            cur: usize,
-            reader: BufReader<Box<dyn Read + Send + 'a>>,
-            offsets: Vec<u64>,
-        }
-
-        impl<'a> KeyForRadixTournamentTree for SpillCursor<'a> {
-            fn rdx(&self) -> usize {
-                self.cur
-            }
-        }
-
-        impl<'a> SpillCursor<'a> {
-            fn skip_empty_partitions(&mut self) {
-                let offsets = &self.offsets;
-                while self.cur + 1 < offsets.len() && offsets[self.cur + 1] == offsets[self.cur] {
-                    self.cur += 1;
-                }
-            }
-        }
-
         // write rest data into an in-memory buffer
         if !data.is_empty() {
            let mut spill = Box::new(vec![]);
            let writer = spill.get_buf_writer();
            let offsets = data.write(writer)?;
            self.update_mem_used(spill.len()).await?;
-            spills.push(ShuffleSpill { spill, offsets });
+            spills.push(Offsetted::new(offsets, spill));
         }
 
-        let num_output_partitions = self.num_output_partitions;
-        let mut offsets = vec![0];
-        // append partition in each spills
+        let num_output_partitions = self.num_output_partitions;
         let output_io_time = self.output_io_time.clone();
         tokio::task::spawn_blocking(move || {
             let mut output_data = output_io_time.wrap_writer(
@@ -250,57 +227,23 @@ impl ShuffleRepartitioner for SortShuffleRepartitioner {
                     .open(&index_file)?,
             );
 
-            if !spills.is_empty() {
-                // select partitions from spills
-                let mut cursors = RadixTournamentTree::new(
-                    spills
-                        .iter_mut()
-                        .map(|spill| SpillCursor {
-                            cur: 0,
-                            reader: spill.spill.get_buf_reader(),
-                            offsets: std::mem::take(&mut spill.offsets),
-                        })
-                        .map(|mut spill| {
-                            spill.skip_empty_partitions();
-                            spill
-                        })
-                        .filter(|spill| spill.cur < spill.offsets.len())
-                        .collect(),
-                    num_output_partitions,
-                );
-
-                let mut cur_partition_id = 0;
-                loop {
-                    let mut min_spill = cursors.peek_mut();
-                    if min_spill.cur + 1 >= min_spill.offsets.len() {
-                        break;
-                    }
-
-                    while cur_partition_id < min_spill.cur {
-                        offsets.push(output_data.0.stream_position()?);
-                        cur_partition_id += 1;
-                    }
-                    let (spill_offset_start, spill_offset_end) = (
-                        min_spill.offsets[cur_partition_id],
-                        min_spill.offsets[cur_partition_id + 1],
-                    );
-
-                    let spill_range = spill_offset_start as usize..spill_offset_end as usize;
-                    let reader = &mut min_spill.reader;
-                    std::io::copy(&mut reader.take(spill_range.len() as u64), &mut output_data)?;
+            let mut merge_iter = OffsettedMergeIterator::new(
+                num_output_partitions,
+                spills
+                    .into_iter()
+                    .map(|spill| spill.map_data(|s| OwnedSpillBufReader::from(s)))
+                    .collect(),
+            );
 
-                    // forward partition id in min_spill
-                    min_spill.cur += 1;
-                    min_spill.skip_empty_partitions();
-                }
+            while let Some((_partition_id, reader, range)) = merge_iter.next() {
+                let mut reader = reader.buf_reader().take(range.end - range.start);
+                std::io::copy(&mut reader, &mut output_data)?;
             }
-
-            // add one extra offset at last to ease partition length computation
-            offsets.resize(num_output_partitions + 1, output_data.0.stream_position()?);
+            let offsets = merge_iter.merged_offsets();
 
             // write index file
             let mut offsets_data = vec![];
-            for offset in offsets {
+            for &offset in offsets {
                 offsets_data.extend_from_slice(&(offset as i64).to_le_bytes()[..]);
             }
             output_index.write_all(&offsets_data)?;
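After the merge, merged_offsets() leaves one absolute data-file position per partition plus a trailing end position, and the index file written above is exactly those positions encoded as little-endian i64 values. A small sketch of that layout and the length computation it enables (helper names here are hypothetical):

    use std::io::{self, Write};

    // entry i = absolute start of partition i in the data file; one extra
    // entry at the end makes partition i's length offsets[i + 1] - offsets[i]
    fn write_index<W: Write>(offsets: &[u64], mut out: W) -> io::Result<()> {
        for &offset in offsets {
            out.write_all(&(offset as i64).to_le_bytes())?;
        }
        Ok(())
    }

    fn partition_len(offsets: &[u64], partition_id: usize) -> u64 {
        offsets[partition_id + 1] - offsets[partition_id]
    }

    fn main() -> io::Result<()> {
        // three partitions of 10, 0 and 22 bytes
        let offsets = [0u64, 10, 10, 32];
        let mut index = vec![];
        write_index(&offsets, &mut index)?;
        assert_eq!(index.len(), 4 * 8); // four i64 entries
        assert_eq!(partition_len(&offsets, 2), 22);
        Ok(())
    }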
diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs
index ad602143..2c397fb7 100644
--- a/native-engine/datafusion-ext-plans/src/sort_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs
@@ -37,7 +37,9 @@ use bytesize::ByteSize;
 use datafusion::{
     common::{DataFusionError, Result, Statistics},
     execution::context::TaskContext,
-    physical_expr::{expressions::Column, EquivalenceProperties, PhysicalSortExpr},
+    physical_expr::{
+        expressions::Column, EquivalenceProperties, PhysicalExprRef, PhysicalSortExpr,
+    },
     physical_plan::{
         metrics::{ExecutionPlanMetricsSet, MetricsSet},
         DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties,
@@ -102,6 +104,23 @@ impl SortExec {
     }
 }
 
+pub fn create_default_ascending_sort_exec(
+    input: Arc<dyn ExecutionPlan>,
+    key_exprs: &[PhysicalExprRef],
+) -> Arc<dyn ExecutionPlan> {
+    Arc::new(SortExec::new(
+        input,
+        key_exprs
+            .iter()
+            .map(|e| PhysicalSortExpr {
+                expr: e.clone(),
+                options: Default::default(),
+            })
+            .collect(),
+        None,
+    ))
+}
+
 impl DisplayAs for SortExec {
     fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
         let exprs = self
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
index a54546b0..1065f0a8 100644
--- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
+++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -57,6 +57,7 @@ pub struct SortMergeJoinExec {
     on: JoinOn,
     join_type: JoinType,
     sort_options: Vec<SortOptions>,
+    join_params: OnceCell<JoinParams>,
     schema: SchemaRef,
     metrics: ExecutionPlanMetricsSet,
     props: OnceCell<PlanProperties>,
@@ -78,6 +79,32 @@ impl SortMergeJoinExec {
             on,
             join_type,
             sort_options,
+            join_params: OnceCell::new(),
             metrics: ExecutionPlanMetricsSet::new(),
             props: OnceCell::new(),
         })
     }
+
+    pub fn try_new_with_join_params(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        join_params: JoinParams,
+    ) -> Result<Self> {
+        let on = join_params
+            .left_keys
+            .iter()
+            .zip(&join_params.right_keys)
+            .map(|(l, r)| (l.clone(), r.clone()))
+            .collect();
+
+        Ok(Self {
+            schema: join_params.output_schema.clone(),
+            left,
+            right,
+            on,
+            join_type: join_params.join_type,
+            sort_options: join_params.sort_options.clone(),
+            join_params: OnceCell::with_value(join_params),
+            metrics: ExecutionPlanMetricsSet::new(),
+            props: OnceCell::new(),
+        })
+    }
@@ -132,7 +159,10 @@ impl SortMergeJoinExec {
         context: Arc<TaskContext>,
         projection: Vec<usize>,
     ) -> Result<SendableRecordBatchStream> {
-        let join_params = self.create_join_params(&projection)?;
+        let join_params = self
+            .join_params
+            .get_or_try_init(|| self.create_join_params(&projection))?
+            .clone();
         let exec_ctx = ExecutionContext::new(
             context,
             partition,
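A hypothetical sketch of how these two new constructors are meant to compose when a hash join gives up (not part of this diff; JoinParams, create_default_ascending_sort_exec and SortMergeJoinExec as introduced above, remaining crate imports elided):

    use std::sync::Arc;

    use datafusion::{common::Result, physical_plan::ExecutionPlan};
    // assumed in scope: JoinParams, SortMergeJoinExec,
    // create_default_ascending_sort_exec (paths per this patch's crate layout)

    // sort both children ascending on the join keys, then build the SMJ from
    // the JoinParams already computed for the hash join, so output schema,
    // join type and key exprs stay identical to the hash join's
    fn fallback_to_smj(
        left: Arc<dyn ExecutionPlan>,
        right: Arc<dyn ExecutionPlan>,
        join_params: JoinParams,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let left = create_default_ascending_sort_exec(left, &join_params.left_keys);
        let right = create_default_ascending_sort_exec(right, &join_params.right_keys);
        let smj = SortMergeJoinExec::try_new_with_join_params(left, right, join_params)?;
        Ok(Arc::new(smj))
    }

Caching the params in a OnceCell also means a plan that already resolved its JoinParams reuses them on re-execution instead of recomputing them per partition.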
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java
index ebf28e48..10927c8d 100644
--- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java
+++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java
@@ -62,7 +62,16 @@ public enum BlazeConf {
   FORCE_SHUFFLED_HASH_JOIN("spark.blaze.forceShuffledHashJoin", false),
 
   // spark spill compression codec
-  SPILL_COMPRESSION_CODEC("spark.blaze.spill.compression.codec", "lz4");
+  SPILL_COMPRESSION_CODEC("spark.blaze.spill.compression.codec", "lz4"),
+
+  // enable falling back from hash join to sort merge join when the hash table is too big
+  SMJ_FALLBACK_ENABLE("spark.blaze.smjfallback.enable", false),
+
+  // smj fallback rows threshold
+  SMJ_FALLBACK_ROWS_THRESHOLD("spark.blaze.smjfallback.rows.threshold", 10000000),
+
+  // smj fallback memory size threshold (default: 128m)
+  SMJ_FALLBACK_MEM_SIZE_THRESHOLD("spark.blaze.smjfallback.mem.threshold", 134217728);
 
   public final String key;
   private final Object defaultValue;
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala
index fd10e440..a1bc628c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeHelper.scala
@@ -104,6 +104,7 @@ object NativeHelper extends Logging {
       "probed_side_search_time" -> nanoTimingMetric("Native.probed_side_search_time"),
       "probed_side_compare_time" -> nanoTimingMetric("Native.probed_side_compare_time"),
       "build_output_time" -> nanoTimingMetric("Native.build_output_time"),
+      "fallback_sort_merge_join_time" -> nanoTimingMetric("Native.fallback_sort_merge_join_time"),
       "mem_spill_count" -> metric("Native.mem_spill_count"),
       "mem_spill_size" -> sizeMetric("Native.mem_spill_size"),
       "mem_spill_iotime" -> nanoTimingMetric("Native.mem_spill_iotime"),
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
index a7a79a4f..da5ee2e2 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
@@ -61,6 +61,7 @@ abstract class NativeBroadcastJoinBase(
       "probed_side_search_time",
       "probed_side_compare_time",
       "build_output_time",
+      "fallback_sort_merge_join_time",
       "input_batch_count",
       "input_batch_mem_size",
       "input_row_count"))
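On the native side, the matching conf accessors defined in blaze-jni-bridge would gate the fallback while the hash table is being built. A hedged sketch of such a check (hypothetical helper, the actual trigger point is in the join build path; assumes the define_conf!-generated statics shown earlier, whose value() calls fail when the JNI bridge is not initialized):

    use blaze_jni_bridge::conf::{
        BooleanConf, IntConf, SMJ_FALLBACK_ENABLE, SMJ_FALLBACK_MEM_SIZE_THRESHOLD,
        SMJ_FALLBACK_ROWS_THRESHOLD,
    };

    // fall back once either threshold is exceeded; conf reads fail when the
    // JNI bridge is absent (pure-native tests), in which case the fallback
    // simply stays disabled
    fn should_fallback_to_smj(hash_table_rows: usize, hash_table_mem_size: usize) -> bool {
        if !SMJ_FALLBACK_ENABLE.value().unwrap_or(false) {
            return false;
        }
        let rows_threshold = SMJ_FALLBACK_ROWS_THRESHOLD.value().unwrap_or(i32::MAX) as usize;
        let mem_threshold = SMJ_FALLBACK_MEM_SIZE_THRESHOLD.value().unwrap_or(i32::MAX) as usize;
        hash_table_rows >= rows_threshold || hash_table_mem_size >= mem_threshold
    }

Treating a failed conf read as "disabled" matches the is_jni_bridge_inited guard added to every conf trait in this patch.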
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffledHashJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffledHashJoinBase.scala
index 8b2052c4..2da724b6 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffledHashJoinBase.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeShuffledHashJoinBase.scala
@@ -27,9 +27,7 @@ import org.apache.spark.sql.blaze.NativeConverters
 import org.apache.spark.sql.blaze.NativeHelper
 import org.apache.spark.sql.blaze.NativeRDD
 import org.apache.spark.sql.blaze.NativeSupports
-import org.apache.spark.sql.catalyst.expressions.Ascending
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.execution.BinaryExecNode
 import org.blaze.{protobuf => pb}
@@ -56,19 +54,12 @@ abstract class NativeShuffledHashJoinBase(
       "probed_side_search_time",
       "probed_side_compare_time",
       "build_output_time",
+      "fallback_sort_merge_join_time",
       "input_batch_count",
       "input_batch_mem_size",
       "input_row_count"))
     .toSeq: _*)
 
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
-
-  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
-    // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
-    keys.map(SortOrder(_, Ascending))
-  }
-
   private def nativeSchema = Util.getNativeSchema(output)

   private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) =>