From 81dfe731638e99aed9f8a85577cd86b8b4afba84 Mon Sep 17 00:00:00 2001 From: Avery Date: Mon, 12 Feb 2024 16:17:59 -0500 Subject: [PATCH] [Logical Optimizer] align schema (#61) https://github.com/cmu-db/optd/issues/56 --------- Signed-off-by: AveryQi115 --- optd-datafusion-bridge/src/from_optd.rs | 20 ++++++------ optd-datafusion-bridge/src/lib.rs | 17 +++++++--- optd-datafusion-repr/src/properties/schema.rs | 32 +++++++++++++++---- optd-datafusion-repr/src/rules/joins.rs | 12 +++---- 4 files changed, 54 insertions(+), 27 deletions(-) diff --git a/optd-datafusion-bridge/src/from_optd.rs b/optd-datafusion-bridge/src/from_optd.rs index 4d8eca38..df521716 100644 --- a/optd-datafusion-bridge/src/from_optd.rs +++ b/optd-datafusion-bridge/src/from_optd.rs @@ -36,7 +36,7 @@ use crate::{physical_collector::CollectorExec, OptdPlanContext}; // TODO: current DataType and ConstantType are not 1 to 1 mapping // optd schema stores constantType from data type in catalog.get // for decimal128, the precision is lost -fn from_optd_schema(optd_schema: &OptdSchema) -> Schema { +fn from_optd_schema(optd_schema: OptdSchema) -> Schema { let match_type = |typ: &ConstantType| match typ { ConstantType::Any => unimplemented!(), ConstantType::Bool => DataType::Boolean, @@ -52,12 +52,14 @@ fn from_optd_schema(optd_schema: &OptdSchema) -> Schema { ConstantType::Decimal => DataType::Float64, ConstantType::Utf8String => DataType::Utf8, }; - let fields: Vec<_> = optd_schema - .0 - .iter() - .enumerate() - .map(|(i, typ)| Field::new(format!("c{}", i), match_type(typ), false)) - .collect(); + let mut fields = Vec::with_capacity(optd_schema.len()); + for field in optd_schema.fields { + fields.push(Field::new( + field.name, + match_type(&field.typ), + field.nullable, + )); + } Schema::new(fields) } @@ -437,7 +439,7 @@ impl OptdPlanContext<'_> { #[async_recursion] async fn conv_from_optd_plan_node(&mut self, node: PlanNode) -> Result> { - let mut schema = OptdSchema(vec![]); + let mut schema = OptdSchema { fields: vec![] }; if node.typ() == OptRelNodeTyp::PhysicalEmptyRelation { schema = node.schema(self.optimizer.unwrap().optd_optimizer()); } @@ -485,7 +487,7 @@ impl OptdPlanContext<'_> { } OptRelNodeTyp::PhysicalEmptyRelation => { let physical_node = PhysicalEmptyRelation::from_rel_node(rel_node).unwrap(); - let datafusion_schema: Schema = from_optd_schema(&schema); + let datafusion_schema: Schema = from_optd_schema(schema); Ok(Arc::new(datafusion::physical_plan::empty::EmptyExec::new( physical_node.produce_one_row(), Arc::new(datafusion_schema), diff --git a/optd-datafusion-bridge/src/lib.rs b/optd-datafusion-bridge/src/lib.rs index cac9074d..601e79b8 100644 --- a/optd-datafusion-bridge/src/lib.rs +++ b/optd-datafusion-bridge/src/lib.rs @@ -61,9 +61,10 @@ impl Catalog for DatafusionCatalog { let catalog = self.catalog.catalog("datafusion").unwrap(); let schema = catalog.schema("public").unwrap(); let table = futures_lite::future::block_on(schema.table(name.as_ref())).unwrap(); - let fields = table.schema(); - let mut optd_schema = vec![]; - for field in fields.fields() { + let schema = table.schema(); + let fields = schema.fields(); + let mut optd_fields = Vec::with_capacity(fields.len()); + for field in fields { let dt = match field.data_type() { DataType::Date32 => ConstantType::Date, DataType::Int32 => ConstantType::Int32, @@ -73,9 +74,15 @@ impl Catalog for DatafusionCatalog { DataType::Decimal128(_, _) => ConstantType::Decimal, dt => unimplemented!("{:?}", dt), }; - optd_schema.push(dt); + optd_fields.push(optd_datafusion_repr::properties::schema::Field { + name: field.name().to_string(), + typ: dt, + nullable: field.is_nullable(), + }); + } + optd_datafusion_repr::properties::schema::Schema { + fields: optd_fields, } - optd_datafusion_repr::properties::schema::Schema(optd_schema) } } diff --git a/optd-datafusion-repr/src/properties/schema.rs b/optd-datafusion-repr/src/properties/schema.rs index 09ff2eee..6ef5872c 100644 --- a/optd-datafusion-repr/src/properties/schema.rs +++ b/optd-datafusion-repr/src/properties/schema.rs @@ -3,12 +3,19 @@ use optd_core::property::PropertyBuilder; use crate::plan_nodes::{ConstantType, OptRelNodeTyp}; #[derive(Clone, Debug)] -pub struct Schema(pub Vec); +pub struct Field { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} +#[derive(Clone, Debug)] +pub struct Schema { + pub fields: Vec, +} -// TODO: add names, nullable to schema impl Schema { pub fn len(&self) -> usize { - self.0.len() + self.fields.len() } pub fn is_empty(&self) -> bool { @@ -48,11 +55,24 @@ impl PropertyBuilder for SchemaPropertyBuilder { OptRelNodeTyp::Filter => children[0].clone(), OptRelNodeTyp::Join(_) => { let mut schema = children[0].clone(); - schema.0.extend(children[1].clone().0); + let schema2 = children[1].clone(); + schema.fields.extend(schema2.fields); schema } - OptRelNodeTyp::List => Schema(vec![ConstantType::Any; children.len()]), - _ => Schema(vec![]), + OptRelNodeTyp::List => { + // TODO: calculate real is_nullable for aggregations + Schema { + fields: vec![ + Field { + name: "unnamed".to_string(), + typ: ConstantType::Any, + nullable: true + }; + children.len() + ], + } + } + _ => Schema { fields: vec![] }, } } diff --git a/optd-datafusion-repr/src/rules/joins.rs b/optd-datafusion-repr/src/rules/joins.rs index 5194d194..a6453e3c 100644 --- a/optd-datafusion-repr/src/rules/joins.rs +++ b/optd-datafusion-repr/src/rules/joins.rs @@ -68,7 +68,7 @@ fn apply_join_commute( cond, JoinType::Inner, ); - let mut proj_expr = Vec::with_capacity(left_schema.0.len() + right_schema.0.len()); + let mut proj_expr = Vec::with_capacity(left_schema.len() + right_schema.len()); for i in 0..left_schema.len() { proj_expr.push(ColumnRefExpr::new(right_schema.len() + i).into_expr()); } @@ -218,13 +218,11 @@ fn apply_hash_join( let Some(mut right_expr) = ColumnRefExpr::from_rel_node(right_expr.into_rel_node()) else { return vec![]; }; - let can_convert = if left_expr.index() < left_schema.0.len() - && right_expr.index() >= left_schema.0.len() + let can_convert = if left_expr.index() < left_schema.len() + && right_expr.index() >= left_schema.len() { true - } else if right_expr.index() < left_schema.0.len() - && left_expr.index() >= left_schema.0.len() - { + } else if right_expr.index() < left_schema.len() && left_expr.index() >= left_schema.len() { (left_expr, right_expr) = (right_expr, left_expr); true } else { @@ -232,7 +230,7 @@ fn apply_hash_join( }; if can_convert { - let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.0.len()); + let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.len()); let node = PhysicalHashJoin::new( PlanNode::from_group(left.into()), PlanNode::from_group(right.into()),