Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

feat: add cost estimation for agg #144

Merged
merged 9 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 99 additions & 33 deletions optd-datafusion-repr/src/cost/base_cost.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use crate::plan_nodes::{
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, LogOpType, OptRelNode, UnOpType,
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, ExprList, LogOpType, OptRelNode, UnOpType,
};
use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs};
use crate::{
Expand All @@ -11,8 +11,8 @@ use crate::{
use arrow_schema::{ArrowError, DataType};
use datafusion::arrow::array::{
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array,
UInt32Array, UInt8Array,
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
UInt16Array, UInt32Array, UInt8Array,
};
use itertools::Itertools;
use optd_core::{
Expand All @@ -22,6 +22,7 @@ use optd_core::{
};
use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
use optd_gungnir::stats::tdigest::{self, TDigest};
use optd_gungnir::utils::arith_encoder;
use serde::{Deserialize, Serialize};

fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
Expand Down Expand Up @@ -181,6 +182,7 @@ impl DataFusionPerTableStats {
| DataType::UInt32
| DataType::Float32
| DataType::Float64
| DataType::Utf8
)
}

Expand Down Expand Up @@ -222,6 +224,10 @@ impl DataFusionPerTableStats {
val as f64
}

fn str_to_f64(string: &str) -> f64 {
arith_encoder::encode(string)
}

match col_type {
DataType::Boolean => {
generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe })
Expand Down Expand Up @@ -256,6 +262,9 @@ impl DataFusionPerTableStats {
DataType::Decimal128(_, _) => {
generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 })
}
DataType::Utf8 => {
generate_stats_for_col!({ col, distr, hll, StringArray, str_to_f64 })
}
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -323,6 +332,10 @@ const DEFAULT_EQ_SEL: f64 = 0.005;
const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333;
// Default selectivity estimate for pattern-match operators such as LIKE
const DEFAULT_MATCH_SEL: f64 = 0.005;
// Default selectivity if we have no information
const DEFAULT_UNK_SEL: f64 = 0.005;
// Default n-distinct estimate for derived columns or columns lacking statistics
const DEFAULT_N_DISTINCT: u64 = 200;

const INVALID_SEL: f64 = 0.01;

Expand Down Expand Up @@ -401,37 +414,33 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
OptRelNodeTyp::PhysicalEmptyRelation => Self::cost(0.5, 0.01, 0.0),
OptRelNodeTyp::PhysicalLimit => {
let (row_cnt, compute_cost, _) = Self::cost_tuple(&children[0]);
let row_cnt = if let Some(context) = context {
if let Some(optimizer) = optimizer {
let mut fetch_expr =
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
assert!(
fetch_expr.len() == 1,
"fetch expression should be the only expr in the group"
);
let fetch_expr = fetch_expr.pop().unwrap();
assert!(
matches!(
fetch_expr.typ,
OptRelNodeTyp::Constant(ConstantType::UInt64)
),
"fetch type can only be UInt64"
);
let fetch = ConstantExpr::from_rel_node(fetch_expr)
.unwrap()
.value()
.as_u64();
// u64::MAX represents None
if fetch == u64::MAX {
row_cnt
} else {
row_cnt.min(fetch as f64)
}
let row_cnt = if let (Some(context), Some(optimizer)) = (context, optimizer) {
let mut fetch_expr =
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
assert!(
fetch_expr.len() == 1,
"fetch expression should be the only expr in the group"
);
let fetch_expr = fetch_expr.pop().unwrap();
assert!(
matches!(
fetch_expr.typ,
OptRelNodeTyp::Constant(ConstantType::UInt64)
),
"fetch type can only be UInt64"
);
let fetch = ConstantExpr::from_rel_node(fetch_expr)
.unwrap()
.value()
.as_u64();
// u64::MAX represents None
if fetch == u64::MAX {
row_cnt
} else {
panic!("compute_cost() should not be called if optimizer is None")
row_cnt.min(fetch as f64)
}
} else {
panic!("compute_cost() should not be called if context is None")
(row_cnt * DEFAULT_UNK_SEL).max(1.0)
};
Self::cost(row_cnt, compute_cost, 0.0)
}
Expand Down Expand Up @@ -499,10 +508,15 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
Self::cost(row_cnt, row_cnt * row_cnt.ln_1p().max(1.0), 0.0)
}
OptRelNodeTyp::PhysicalAgg => {
let (row_cnt, _, _) = Self::cost_tuple(&children[0]);
let child_row_cnt = Self::row_cnt(&children[0]);
let row_cnt = self.get_agg_row_cnt(context, optimizer, child_row_cnt);
let (_, compute_cost_1, _) = Self::cost_tuple(&children[1]);
let (_, compute_cost_2, _) = Self::cost_tuple(&children[2]);
Self::cost(row_cnt, row_cnt * (compute_cost_1 + compute_cost_2), 0.0)
Self::cost(
row_cnt,
child_row_cnt * (compute_cost_1 + compute_cost_2),
0.0,
)
}
OptRelNodeTyp::List => {
let compute_cost = children
Expand Down Expand Up @@ -544,6 +558,58 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
}
}

fn get_agg_row_cnt(
&self,
context: Option<RelNodeContext>,
optimizer: Option<&CascadesOptimizer<OptRelNodeTyp>>,
child_row_cnt: f64,
) -> f64 {
if let (Some(context), Some(optimizer)) = (context, optimizer) {
let group_by_id = context.children_group_ids[2];
let mut group_by_exprs: Vec<Arc<RelNode<OptRelNodeTyp>>> =
optimizer.get_all_group_bindings(group_by_id, false);
assert!(
group_by_exprs.len() == 1,
"ExprList expression should be the only expression in the GROUP BY group"
);
let group_by = group_by_exprs.pop().unwrap();
let group_by = ExprList::from_rel_node(group_by).unwrap();
if group_by.is_empty() {
1.0
} else {
// Multiply the n-distinct of all the group by columns.
// TODO: improve with multi-dimensional n-distinct
let base_table_col_refs = optimizer
.get_property_by_group::<ColumnRefPropertyBuilder>(context.group_id, 1);
base_table_col_refs
.iter()
.take(group_by.len())
.map(|col_ref| match col_ref {
ColumnRef::BaseTableColumnRef { table, col_idx } => {
let table_stats = self.per_table_stats_map.get(table);
let column_stats = table_stats.map(|table_stats| {
table_stats.per_column_stats_vec.get(*col_idx).unwrap()
});

if let Some(Some(column_stats)) = column_stats {
column_stats.ndistinct as f64
} else {
// The column type is not supported or stats are missing.
DEFAULT_N_DISTINCT as f64
}
}
ColumnRef::Derived => DEFAULT_N_DISTINCT as f64,
_ => panic!(
"GROUP BY base table column ref must either be derived or base table"
),
})
.product()
}
} else {
(child_row_cnt * DEFAULT_UNK_SEL).max(1.0)
}
}

/// The expr_tree input must be a "mixed expression tree"
/// An "expression node" refers to a RelNode that returns true for is_expression()
/// A "full expression tree" is where every node in the tree is an expression node
Expand Down
34 changes: 25 additions & 9 deletions optd-datafusion-repr/src/properties/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,26 @@ use std::sync::Arc;
use optd_core::property::PropertyBuilder;

use super::DEFAULT_NAME;
use crate::plan_nodes::{ConstantType, EmptyRelationData, OptRelNodeTyp};
use crate::plan_nodes::{ConstantType, EmptyRelationData, FuncType, OptRelNodeTyp};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Field {
pub name: String,
pub typ: ConstantType,
pub nullable: bool,
}

impl Field {
/// Generate a field that is only a place holder whose members are never used.
fn placeholder() -> Self {
Self {
name: DEFAULT_NAME.to_string(),
typ: ConstantType::Any,
nullable: true,
}
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Schema {
pub fields: Vec<Field>,
Expand Down Expand Up @@ -87,14 +99,18 @@ impl PropertyBuilder<OptRelNodeTyp> for SchemaPropertyBuilder {
Schema { fields }
}
OptRelNodeTyp::LogOp(_) => Schema {
fields: vec![
Field {
name: DEFAULT_NAME.to_string(),
typ: ConstantType::Any,
nullable: true
};
children.len()
],
fields: vec![Field::placeholder(); children.len()],
},
OptRelNodeTyp::Agg => {
let mut group_by_schema = children[1].clone();
let agg_schema = children[2].clone();
group_by_schema.fields.extend(agg_schema.fields);
group_by_schema
}
OptRelNodeTyp::Func(FuncType::Agg(_)) => Schema {
// TODO: this is just a place holder now.
// The real type should be the column type.
fields: vec![Field::placeholder()],
},
_ => Schema { fields: vec![] },
}
Expand Down
9 changes: 8 additions & 1 deletion optd-gungnir/src/stats/hyperloglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ pub struct HyperLogLog {
alpha: f64, // The normal HLL multiplier factor.
}

// Serialize common data types for hashing (&str).
impl ByteSerializable for &str {
fn to_bytes(&self) -> Vec<u8> {
self.as_bytes().to_vec()
}
}

// Serialize common data types for hashing (String).
impl ByteSerializable for String {
fn to_bytes(&self) -> Vec<u8> {
self.as_bytes().to_vec()
self.as_str().to_bytes()
}
}

Expand Down
1 change: 0 additions & 1 deletion optd-sqlplannertest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ impl DatafusionDBMS {
task: &str,
flags: &[String],
) -> Result<()> {
println!("task_explain(): called on sql={}", sql);
use std::fmt::Write;

let with_logical = flags.contains(&"with_logical".to_string());
Expand Down
Loading
Loading