diff --git a/optd-datafusion-repr/src/cost/base_cost.rs b/optd-datafusion-repr/src/cost/base_cost.rs index 55a789ee..7b19e1ec 100644 --- a/optd-datafusion-repr/src/cost/base_cost.rs +++ b/optd-datafusion-repr/src/cost/base_cost.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::plan_nodes::{ - BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, LogOpType, OptRelNode, UnOpType, + BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, ExprList, LogOpType, OptRelNode, UnOpType, }; use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs}; use crate::{ @@ -11,8 +11,8 @@ use crate::{ use arrow_schema::{ArrowError, DataType}; use datafusion::arrow::array::{ Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array, - UInt32Array, UInt8Array, + Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray, + UInt16Array, UInt32Array, UInt8Array, }; use itertools::Itertools; use optd_core::{ @@ -22,6 +22,7 @@ use optd_core::{ }; use optd_gungnir::stats::hyperloglog::{self, HyperLogLog}; use optd_gungnir::stats::tdigest::{self, TDigest}; +use optd_gungnir::utils::arith_encoder; use serde::{Deserialize, Serialize}; fn compute_plan_node_cost>( @@ -181,6 +182,7 @@ impl DataFusionPerTableStats { | DataType::UInt32 | DataType::Float32 | DataType::Float64 + | DataType::Utf8 ) } @@ -222,6 +224,10 @@ impl DataFusionPerTableStats { val as f64 } + fn str_to_f64(string: &str) -> f64 { + arith_encoder::encode(string) + } + match col_type { DataType::Boolean => { generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe }) @@ -256,6 +262,9 @@ impl DataFusionPerTableStats { DataType::Decimal128(_, _) => { generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 }) } + DataType::Utf8 => { + generate_stats_for_col!({ col, distr, hll, StringArray, str_to_f64 }) + } _ => unreachable!(), } } @@ -323,6 +332,10 @@ const DEFAULT_EQ_SEL: f64 = 0.005; const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333; // Default selectivity estimate for pattern-match operators such as LIKE const DEFAULT_MATCH_SEL: f64 = 0.005; +// Default selectivity if we have no information +const DEFAULT_UNK_SEL: f64 = 0.005; +// Default n-distinct estimate for derived columns or columns lacking statistics +const DEFAULT_N_DISTINCT: u64 = 200; const INVALID_SEL: f64 = 0.01; @@ -401,37 +414,33 @@ impl CostModel for OptCostM OptRelNodeTyp::PhysicalEmptyRelation => Self::cost(0.5, 0.01, 0.0), OptRelNodeTyp::PhysicalLimit => { let (row_cnt, compute_cost, _) = Self::cost_tuple(&children[0]); - let row_cnt = if let Some(context) = context { - if let Some(optimizer) = optimizer { - let mut fetch_expr = - optimizer.get_all_group_bindings(context.children_group_ids[2], false); - assert!( - fetch_expr.len() == 1, - "fetch expression should be the only expr in the group" - ); - let fetch_expr = fetch_expr.pop().unwrap(); - assert!( - matches!( - fetch_expr.typ, - OptRelNodeTyp::Constant(ConstantType::UInt64) - ), - "fetch type can only be UInt64" - ); - let fetch = ConstantExpr::from_rel_node(fetch_expr) - .unwrap() - .value() - .as_u64(); - // u64::MAX represents None - if fetch == u64::MAX { - row_cnt - } else { - row_cnt.min(fetch as f64) - } + let row_cnt = if let (Some(context), Some(optimizer)) = (context, optimizer) { + let mut fetch_expr = + optimizer.get_all_group_bindings(context.children_group_ids[2], false); + assert!( + fetch_expr.len() == 1, + "fetch expression should be the only expr in the group" + ); + let fetch_expr = fetch_expr.pop().unwrap(); + assert!( + matches!( + fetch_expr.typ, + OptRelNodeTyp::Constant(ConstantType::UInt64) + ), + "fetch type can only be UInt64" + ); + let fetch = ConstantExpr::from_rel_node(fetch_expr) + .unwrap() + .value() + .as_u64(); + // u64::MAX represents None + if fetch == u64::MAX { + row_cnt } else { - panic!("compute_cost() should not be called if optimizer is None") + row_cnt.min(fetch as f64) } } else { - panic!("compute_cost() should not be called if context is None") + (row_cnt * DEFAULT_UNK_SEL).max(1.0) }; Self::cost(row_cnt, compute_cost, 0.0) } @@ -499,10 +508,15 @@ impl CostModel for OptCostM Self::cost(row_cnt, row_cnt * row_cnt.ln_1p().max(1.0), 0.0) } OptRelNodeTyp::PhysicalAgg => { - let (row_cnt, _, _) = Self::cost_tuple(&children[0]); + let child_row_cnt = Self::row_cnt(&children[0]); + let row_cnt = self.get_agg_row_cnt(context, optimizer, child_row_cnt); let (_, compute_cost_1, _) = Self::cost_tuple(&children[1]); let (_, compute_cost_2, _) = Self::cost_tuple(&children[2]); - Self::cost(row_cnt, row_cnt * (compute_cost_1 + compute_cost_2), 0.0) + Self::cost( + row_cnt, + child_row_cnt * (compute_cost_1 + compute_cost_2), + 0.0, + ) } OptRelNodeTyp::List => { let compute_cost = children @@ -544,6 +558,58 @@ impl OptCostModel { } } + fn get_agg_row_cnt( + &self, + context: Option, + optimizer: Option<&CascadesOptimizer>, + child_row_cnt: f64, + ) -> f64 { + if let (Some(context), Some(optimizer)) = (context, optimizer) { + let group_by_id = context.children_group_ids[2]; + let mut group_by_exprs: Vec>> = + optimizer.get_all_group_bindings(group_by_id, false); + assert!( + group_by_exprs.len() == 1, + "ExprList expression should be the only expression in the GROUP BY group" + ); + let group_by = group_by_exprs.pop().unwrap(); + let group_by = ExprList::from_rel_node(group_by).unwrap(); + if group_by.is_empty() { + 1.0 + } else { + // Multiply the n-distinct of all the group by columns. + // TODO: improve with multi-dimensional n-distinct + let base_table_col_refs = optimizer + .get_property_by_group::(context.group_id, 1); + base_table_col_refs + .iter() + .take(group_by.len()) + .map(|col_ref| match col_ref { + ColumnRef::BaseTableColumnRef { table, col_idx } => { + let table_stats = self.per_table_stats_map.get(table); + let column_stats = table_stats.map(|table_stats| { + table_stats.per_column_stats_vec.get(*col_idx).unwrap() + }); + + if let Some(Some(column_stats)) = column_stats { + column_stats.ndistinct as f64 + } else { + // The column type is not supported or stats are missing. + DEFAULT_N_DISTINCT as f64 + } + } + ColumnRef::Derived => DEFAULT_N_DISTINCT as f64, + _ => panic!( + "GROUP BY base table column ref must either be derived or base table" + ), + }) + .product() + } + } else { + (child_row_cnt * DEFAULT_UNK_SEL).max(1.0) + } + } + /// The expr_tree input must be a "mixed expression tree" /// An "expression node" refers to a RelNode that returns true for is_expression() /// A "full expression tree" is where every node in the tree is an expression node diff --git a/optd-datafusion-repr/src/properties/schema.rs b/optd-datafusion-repr/src/properties/schema.rs index f4967156..04f2e7cb 100644 --- a/optd-datafusion-repr/src/properties/schema.rs +++ b/optd-datafusion-repr/src/properties/schema.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use optd_core::property::PropertyBuilder; use super::DEFAULT_NAME; -use crate::plan_nodes::{ConstantType, EmptyRelationData, OptRelNodeTyp}; +use crate::plan_nodes::{ConstantType, EmptyRelationData, FuncType, OptRelNodeTyp}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Field { @@ -12,6 +12,18 @@ pub struct Field { pub typ: ConstantType, pub nullable: bool, } + +impl Field { + /// Generate a field that is only a place holder whose members are never used. + fn placeholder() -> Self { + Self { + name: DEFAULT_NAME.to_string(), + typ: ConstantType::Any, + nullable: true, + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Schema { pub fields: Vec, @@ -87,14 +99,18 @@ impl PropertyBuilder for SchemaPropertyBuilder { Schema { fields } } OptRelNodeTyp::LogOp(_) => Schema { - fields: vec![ - Field { - name: DEFAULT_NAME.to_string(), - typ: ConstantType::Any, - nullable: true - }; - children.len() - ], + fields: vec![Field::placeholder(); children.len()], + }, + OptRelNodeTyp::Agg => { + let mut group_by_schema = children[1].clone(); + let agg_schema = children[2].clone(); + group_by_schema.fields.extend(agg_schema.fields); + group_by_schema + } + OptRelNodeTyp::Func(FuncType::Agg(_)) => Schema { + // TODO: this is just a place holder now. + // The real type should be the column type. + fields: vec![Field::placeholder()], }, _ => Schema { fields: vec![] }, } diff --git a/optd-gungnir/src/stats/hyperloglog.rs b/optd-gungnir/src/stats/hyperloglog.rs index aca39eb2..f71baa7a 100644 --- a/optd-gungnir/src/stats/hyperloglog.rs +++ b/optd-gungnir/src/stats/hyperloglog.rs @@ -25,10 +25,17 @@ pub struct HyperLogLog { alpha: f64, // The normal HLL multiplier factor. } +// Serialize common data types for hashing (&str). +impl ByteSerializable for &str { + fn to_bytes(&self) -> Vec { + self.as_bytes().to_vec() + } +} + // Serialize common data types for hashing (String). impl ByteSerializable for String { fn to_bytes(&self) -> Vec { - self.as_bytes().to_vec() + self.as_str().to_bytes() } } diff --git a/optd-sqlplannertest/src/lib.rs b/optd-sqlplannertest/src/lib.rs index b62e1082..d69ca31d 100644 --- a/optd-sqlplannertest/src/lib.rs +++ b/optd-sqlplannertest/src/lib.rs @@ -140,7 +140,6 @@ impl DatafusionDBMS { task: &str, flags: &[String], ) -> Result<()> { - println!("task_explain(): called on sql={}", sql); use std::fmt::Write; let with_logical = flags.contains(&"with_logical".to_string()); diff --git a/optd-sqlplannertest/tests/tpch.planner.sql b/optd-sqlplannertest/tests/tpch.planner.sql index 99e00812..1e5dc603 100644 --- a/optd-sqlplannertest/tests/tpch.planner.sql +++ b/optd-sqlplannertest/tests/tpch.planner.sql @@ -1522,6 +1522,171 @@ PhysicalLimit { skip: 0, fetch: 20 } └── PhysicalScan { table: nation } */ +-- TPC-H Q11 +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'CHINA' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'CHINA' + ) +order by + value desc; + +/* +LogicalSort +├── exprs:SortOrder { order: Desc } +│ └── #1 +└── LogicalProjection { exprs: [ #0, #1 ] } + └── LogicalJoin + ├── join_type: Inner + ├── cond:Gt + │ ├── Cast { cast_to: Decimal128(38, 15), expr: #1 } + │ └── #2 + ├── LogicalAgg + │ ├── exprs:Agg(Sum) + │ │ └── Mul + │ │ ├── #2 + │ │ └── Cast { cast_to: Decimal128(10, 0), expr: #1 } + │ ├── groups: [ #0 ] + │ └── LogicalProjection { exprs: [ #0, #1, #2 ] } + │ └── LogicalJoin + │ ├── join_type: Inner + │ ├── cond:Eq + │ │ ├── #3 + │ │ └── #4 + │ ├── LogicalProjection { exprs: [ #0, #2, #3, #5 ] } + │ │ └── LogicalJoin + │ │ ├── join_type: Inner + │ │ ├── cond:Eq + │ │ │ ├── #1 + │ │ │ └── #4 + │ │ ├── LogicalProjection { exprs: [ #0, #1, #2, #3 ] } + │ │ │ └── LogicalScan { table: partsupp } + │ │ └── LogicalProjection { exprs: [ #0, #3 ] } + │ │ └── LogicalScan { table: supplier } + │ └── LogicalProjection { exprs: [ #0 ] } + │ └── LogicalFilter + │ ├── cond:Eq + │ │ ├── #1 + │ │ └── "CHINA" + │ └── LogicalProjection { exprs: [ #0, #1 ] } + │ └── LogicalScan { table: nation } + └── LogicalProjection + ├── exprs:Cast + │ ├── cast_to: Decimal128(38, 15) + │ ├── expr:Mul + │ │ ├── Cast { cast_to: Float64, expr: #0 } + │ │ └── 0.0001 + + └── LogicalAgg + ├── exprs:Agg(Sum) + │ └── Mul + │ ├── #1 + │ └── Cast { cast_to: Decimal128(10, 0), expr: #0 } + ├── groups: [] + └── LogicalProjection { exprs: [ #0, #1 ] } + └── LogicalJoin + ├── join_type: Inner + ├── cond:Eq + │ ├── #2 + │ └── #3 + ├── LogicalProjection { exprs: [ #1, #2, #4 ] } + │ └── LogicalJoin + │ ├── join_type: Inner + │ ├── cond:Eq + │ │ ├── #0 + │ │ └── #3 + │ ├── LogicalProjection { exprs: [ #1, #2, #3 ] } + │ │ └── LogicalScan { table: partsupp } + │ └── LogicalProjection { exprs: [ #0, #3 ] } + │ └── LogicalScan { table: supplier } + └── LogicalProjection { exprs: [ #0 ] } + └── LogicalFilter + ├── cond:Eq + │ ├── #1 + │ └── "CHINA" + └── LogicalProjection { exprs: [ #0, #1 ] } + └── LogicalScan { table: nation } +PhysicalSort +├── exprs:SortOrder { order: Desc } +│ └── #1 +└── PhysicalProjection { exprs: [ #0, #1 ] } + └── PhysicalProjection { exprs: [ #0, #1 ] } + └── PhysicalNestedLoopJoin + ├── join_type: Inner + ├── cond:Gt + │ ├── Cast { cast_to: Decimal128(38, 15), expr: #1 } + │ └── #0 + ├── PhysicalProjection + │ ├── exprs:Cast + │ │ ├── cast_to: Decimal128(38, 15) + │ │ ├── expr:Mul + │ │ │ ├── Cast { cast_to: Float64, expr: #0 } + │ │ │ └── 0.0001 + + │ └── PhysicalAgg + │ ├── aggrs:Agg(Sum) + │ │ └── Mul + │ │ ├── #1 + │ │ └── Cast { cast_to: Decimal128(10, 0), expr: #0 } + │ ├── groups: [] + │ └── PhysicalProjection { exprs: [ #0, #1 ] } + │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #2 ], right_keys: [ #0 ] } + │ ├── PhysicalProjection { exprs: [ #1, #2, #4 ] } + │ │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + │ │ ├── PhysicalProjection { exprs: [ #1, #2, #3 ] } + │ │ │ └── PhysicalScan { table: partsupp } + │ │ └── PhysicalProjection { exprs: [ #0, #3 ] } + │ │ └── PhysicalScan { table: supplier } + │ └── PhysicalProjection { exprs: [ #0 ] } + │ └── PhysicalFilter + │ ├── cond:Eq + │ │ ├── #1 + │ │ └── "CHINA" + │ └── PhysicalProjection { exprs: [ #0, #1 ] } + │ └── PhysicalScan { table: nation } + └── PhysicalAgg + ├── aggrs:Agg(Sum) + │ └── Mul + │ ├── #2 + │ └── Cast { cast_to: Decimal128(10, 0), expr: #1 } + ├── groups: [ #0 ] + └── PhysicalProjection { exprs: [ #0, #1, #2 ] } + └── PhysicalHashJoin { join_type: Inner, left_keys: [ #3 ], right_keys: [ #0 ] } + ├── PhysicalProjection { exprs: [ #0, #2, #3, #5 ] } + │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #1 ], right_keys: [ #0 ] } + │ ├── PhysicalProjection { exprs: [ #0, #1, #2, #3 ] } + │ │ └── PhysicalScan { table: partsupp } + │ └── PhysicalProjection { exprs: [ #0, #3 ] } + │ └── PhysicalScan { table: supplier } + └── PhysicalProjection { exprs: [ #0 ] } + └── PhysicalFilter + ├── cond:Eq + │ ├── #1 + │ └── "CHINA" + └── PhysicalProjection { exprs: [ #0, #1 ] } + └── PhysicalScan { table: nation } +*/ + -- TPC-H Q12 SELECT l_shipmode, @@ -1852,55 +2017,56 @@ PhysicalSort ├── exprs:SortOrder { order: Asc } │ └── #0 └── PhysicalProjection { exprs: [ #0, #1, #2, #3, #4 ] } - └── PhysicalHashJoin { join_type: Inner, left_keys: [ #4 ], right_keys: [ #0 ] } - ├── PhysicalProjection { exprs: [ #0, #1, #2, #3, #5 ] } - │ └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } - │ ├── PhysicalProjection { exprs: [ #0, #1, #2, #4 ] } - │ │ └── PhysicalScan { table: supplier } - │ └── PhysicalProjection { exprs: [ #0, #1 ] } - │ └── PhysicalAgg - │ ├── aggrs:Agg(Sum) - │ │ └── Mul - │ │ ├── #1 - │ │ └── Sub - │ │ ├── 1 - │ │ └── #2 - │ ├── groups: [ #0 ] - │ └── PhysicalProjection { exprs: [ #0, #1, #2 ] } - │ └── PhysicalFilter - │ ├── cond:And - │ │ ├── Geq - │ │ │ ├── #3 - │ │ │ └── 8401 - │ │ └── Lt - │ │ ├── #3 - │ │ └── 8491 - │ └── PhysicalProjection { exprs: [ #2, #5, #6, #10 ] } - │ └── PhysicalScan { table: lineitem } - └── PhysicalAgg - ├── aggrs:Agg(Max) - │ └── [ #0 ] - ├── groups: [] - └── PhysicalProjection { exprs: [ #1 ] } - └── PhysicalAgg - ├── aggrs:Agg(Sum) - │ └── Mul - │ ├── #1 - │ └── Sub - │ ├── 1 - │ └── #2 - ├── groups: [ #0 ] - └── PhysicalProjection { exprs: [ #0, #1, #2 ] } - └── PhysicalFilter - ├── cond:And - │ ├── Geq - │ │ ├── #3 - │ │ └── 8401 - │ └── Lt - │ ├── #3 - │ └── 8491 - └── PhysicalProjection { exprs: [ #2, #5, #6, #10 ] } - └── PhysicalScan { table: lineitem } + └── PhysicalProjection { exprs: [ #0, #1, #2, #3, #5, #6 ] } + └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #0 ] } + ├── PhysicalProjection { exprs: [ #0, #1, #2, #4 ] } + │ └── PhysicalScan { table: supplier } + └── PhysicalProjection { exprs: [ #0, #1, #2 ] } + └── PhysicalProjection { exprs: [ #1, #2, #0 ] } + └── PhysicalHashJoin { join_type: Inner, left_keys: [ #0 ], right_keys: [ #1 ] } + ├── PhysicalAgg + │ ├── aggrs:Agg(Max) + │ │ └── [ #0 ] + │ ├── groups: [] + │ └── PhysicalProjection { exprs: [ #1 ] } + │ └── PhysicalAgg + │ ├── aggrs:Agg(Sum) + │ │ └── Mul + │ │ ├── #1 + │ │ └── Sub + │ │ ├── 1 + │ │ └── #2 + │ ├── groups: [ #0 ] + │ └── PhysicalProjection { exprs: [ #0, #1, #2 ] } + │ └── PhysicalFilter + │ ├── cond:And + │ │ ├── Geq + │ │ │ ├── #3 + │ │ │ └── 8401 + │ │ └── Lt + │ │ ├── #3 + │ │ └── 8491 + │ └── PhysicalProjection { exprs: [ #2, #5, #6, #10 ] } + │ └── PhysicalScan { table: lineitem } + └── PhysicalAgg + ├── aggrs:Agg(Sum) + │ └── Mul + │ ├── #1 + │ └── Sub + │ ├── 1 + │ └── #2 + ├── groups: [ #0 ] + └── PhysicalProjection { exprs: [ #0, #1, #2 ] } + └── PhysicalFilter + ├── cond:And + │ ├── Geq + │ │ ├── #3 + │ │ └── 8401 + │ └── Lt + │ ├── #3 + │ └── 8491 + └── PhysicalProjection { exprs: [ #2, #5, #6, #10 ] } + └── PhysicalScan { table: lineitem } */ -- TPC-H Q17 diff --git a/optd-sqlplannertest/tests/tpch.yml b/optd-sqlplannertest/tests/tpch.yml index 2c3fe1dc..4670fa49 100644 --- a/optd-sqlplannertest/tests/tpch.yml +++ b/optd-sqlplannertest/tests/tpch.yml @@ -416,6 +416,37 @@ desc: TPC-H Q10 tasks: - explain:logical_optd,physical_optd +- sql: | + select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'CHINA' + group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001000000 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'CHINA' + ) + order by + value desc; + desc: TPC-H Q11 + tasks: + - explain[with_logical]:logical_optd,physical_optd - sql: | SELECT l_shipmode,