From aa98fd93f572d56e9bac917876b9d7278cc0598f Mon Sep 17 00:00:00 2001 From: Alan Cai Date: Fri, 4 Aug 2023 15:52:20 -0700 Subject: [PATCH] Add SQL aggregates ANY, SOME, EVERY and their COLL_ versions --- CHANGELOG.md | 1 + partiql-eval/src/eval/eval_expr_wrapper.rs | 18 ++- partiql-eval/src/eval/evaluable.rs | 147 ++++++++++++++++----- partiql-eval/src/eval/expr/coll.rs | 86 ++++++++++-- partiql-eval/src/plan.rs | 24 +++- partiql-logical-planner/src/builtins.rs | 72 ++++++++++ partiql-logical-planner/src/lower.rs | 14 +- partiql-logical/src/lib.rs | 6 + partiql-parser/src/parse/parser_state.rs | 3 +- partiql-parser/src/preprocessor.rs | 2 +- partiql-types/src/lib.rs | 4 +- 11 files changed, 324 insertions(+), 53 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d36e985..812a6e95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `partiql_ast_passes::static_typer` for type annotating the AST. - Add ability to parse `ORDER BY`, `LIMIT`, `OFFSET` in children of set operators - Add `OUTER` bag operator (`OUTER UNION`, `OUTER INTERSECT`, `OUTER EXCEPT`) implementation +- Implements the aggregation functions `ANY`, `SOME`, `EVERY` and their `COLL_` versions ### Fixes - Fixes parsing of multiple consecutive path wildcards (e.g. `a[*][*][*]`), unpivot (e.g. `a.*.*.*`), and path expressions (e.g. `a[1 + 2][3 + 4][5 + 6]`)—previously these would not parse correctly. diff --git a/partiql-eval/src/eval/eval_expr_wrapper.rs b/partiql-eval/src/eval/eval_expr_wrapper.rs index 975d61db..a13c2b5d 100644 --- a/partiql-eval/src/eval/eval_expr_wrapper.rs +++ b/partiql-eval/src/eval/eval_expr_wrapper.rs @@ -20,6 +20,8 @@ use std::ops::ControlFlow; #[inline] pub(crate) fn subsumes(typ: &PartiqlType, value: &Value) -> bool { match (typ.kind(), value) { + (_, Value::Null) => true, + (_, Value::Missing) => true, (TypeKind::Any, _) => true, (TypeKind::AnyOf(anyof), val) => anyof.types().any(|typ| subsumes(typ, val)), (TypeKind::Null, Value::Null) => true, @@ -36,10 +38,22 @@ pub(crate) fn subsumes(typ: &PartiqlType, value: &Value) -> bool { Value::String(_), ) => true, (TypeKind::Struct(_), Value::Tuple(_)) => true, - (TypeKind::Bag(_), Value::Bag(_)) => true, + (TypeKind::Bag(b_type), Value::Bag(b_values)) => { + let bag_element_type = b_type.element_type.as_ref(); + let mut b_values = b_values.iter(); + b_values.all(|b_value| { + println!("b_value: {:?}", b_value); + println!("bag_element_type: {:?}", bag_element_type); + subsumes(bag_element_type, b_value) + }) + } (TypeKind::DateTime, Value::DateTime(_)) => true, - (TypeKind::Array(_), Value::List(_)) => true, + (TypeKind::Array(a_type), Value::List(l_values)) => { + let array_element_type = a_type.element_type.as_ref(); + let mut l_values = l_values.iter(); + l_values.all(|l_value| subsumes(array_element_type, l_value)) + } _ => false, } } diff --git a/partiql-eval/src/eval/evaluable.rs b/partiql-eval/src/eval/evaluable.rs index 6b7d97ee..5aec02b3 100644 --- a/partiql-eval/src/eval/evaluable.rs +++ b/partiql-eval/src/eval/evaluable.rs @@ -351,6 +351,8 @@ pub(crate) enum AggFunc { Max(Max), Min(Min), Sum(Sum), + Any(Any), + Every(Every), } impl AggregateFunction for AggFunc { @@ -361,6 +363,8 @@ impl AggregateFunction for AggFunc { AggFunc::Max(v) => v.next_value(input_value, group), AggFunc::Min(v) => v.next_value(input_value, group), AggFunc::Sum(v) => v.next_value(input_value, group), + AggFunc::Any(v) => v.next_value(input_value, group), + AggFunc::Every(v) => v.next_value(input_value, group), } } @@ -371,6 +375,8 @@ impl AggregateFunction for AggFunc { AggFunc::Max(v) => v.compute(group), AggFunc::Min(v) => v.compute(group), AggFunc::Sum(v) => v.compute(group), + AggFunc::Any(v) => v.compute(group), + AggFunc::Every(v) => v.compute(group), } } } @@ -454,7 +460,7 @@ impl Avg { impl AggregateFunction for Avg { fn next_value(&mut self, input_value: &Value, group: &Tuple) { - if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { match self.avgs.get_mut(group) { None => { self.avgs.insert(group.clone(), (1, input_value.clone())); @@ -468,12 +474,9 @@ impl AggregateFunction for Avg { } fn compute(&self, group: &Tuple) -> Result { - match self.avgs.get(group) { - None => Err(EvaluationError::IllegalState( - "Expect group to exist in avgs".to_string(), - )), - Some((0, _)) => Ok(Null), - Some((c, s)) => Ok(s / &Value::from(rust_decimal::Decimal::from(*c))), + match self.avgs.get(group).unwrap_or(&(0, Null)) { + (0, _) => Ok(Null), + (c, s) => Ok(s / &Value::from(rust_decimal::Decimal::from(*c))), } } } @@ -503,7 +506,7 @@ impl Count { impl AggregateFunction for Count { fn next_value(&mut self, input_value: &Value, group: &Tuple) { - if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { match self.counts.get_mut(group) { None => { self.counts.insert(group.clone(), 1); @@ -516,12 +519,7 @@ impl AggregateFunction for Count { } fn compute(&self, group: &Tuple) -> Result { - match self.counts.get(group) { - None => Err(EvaluationError::IllegalState( - "Expect group to exist in counts".to_string(), - )), - Some(val) => Ok(Value::from(val)), - } + Ok(Value::from(self.counts.get(group).unwrap_or(&0))) } } @@ -550,7 +548,7 @@ impl Max { impl AggregateFunction for Max { fn next_value(&mut self, input_value: &Value, group: &Tuple) { - if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { match self.maxes.get_mut(group) { None => { self.maxes.insert(group.clone(), input_value.clone()); @@ -563,12 +561,7 @@ impl AggregateFunction for Max { } fn compute(&self, group: &Tuple) -> Result { - match self.maxes.get(group) { - None => Err(EvaluationError::IllegalState( - "Expect group to exist in maxes".to_string(), - )), - Some(val) => Ok(val.clone()), - } + Ok(self.maxes.get(group).unwrap_or(&Null).clone()) } } @@ -597,7 +590,7 @@ impl Min { impl AggregateFunction for Min { fn next_value(&mut self, input_value: &Value, group: &Tuple) { - if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { match self.mins.get_mut(group) { None => { self.mins.insert(group.clone(), input_value.clone()); @@ -610,12 +603,7 @@ impl AggregateFunction for Min { } fn compute(&self, group: &Tuple) -> Result { - match self.mins.get(group) { - None => Err(EvaluationError::IllegalState( - "Expect group to exist in mins".to_string(), - )), - Some(val) => Ok(val.clone()), - } + Ok(self.mins.get(group).unwrap_or(&Null).clone()) } } @@ -644,7 +632,7 @@ impl Sum { impl AggregateFunction for Sum { fn next_value(&mut self, input_value: &Value, group: &Tuple) { - if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { match self.sums.get_mut(group) { None => { self.sums.insert(group.clone(), input_value.clone()); @@ -657,13 +645,104 @@ impl AggregateFunction for Sum { } fn compute(&self, group: &Tuple) -> Result { - match self.sums.get(group) { - None => Err(EvaluationError::IllegalState( - "Expect group to exist in sums".to_string(), - )), - Some(val) => Ok(val.clone()), + Ok(self.sums.get(group).unwrap_or(&Null).clone()) + } +} + +/// Represents SQL's `ANY`/`SOME` aggregation function +#[derive(Debug)] +pub(crate) struct Any { + anys: HashMap, + aggregator: AggFilterFn, +} + +impl Any { + pub(crate) fn new_distinct() -> Self { + Any { + anys: HashMap::new(), + aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()), + } + } + + pub(crate) fn new_all() -> Self { + Any { + anys: HashMap::new(), + aggregator: AggFilterFn::default(), + } + } +} + +impl AggregateFunction for Any { + fn next_value(&mut self, input_value: &Value, group: &Tuple) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { + match self.anys.get_mut(group) { + None => { + match input_value { + Boolean(_) => self.anys.insert(group.clone(), input_value.clone()), + _ => self.anys.insert(group.clone(), Missing), + }; + } + Some(acc) => { + *acc = match (acc.clone(), input_value) { + (Boolean(l), Value::Boolean(r)) => Value::Boolean(l || *r), + (_, _) => Missing, + }; + } + } } } + + fn compute(&self, group: &Tuple) -> Result { + Ok(self.anys.get(group).unwrap_or(&Null).clone()) + } +} + +/// Represents SQL's `EVERY` aggregation function +#[derive(Debug)] +pub(crate) struct Every { + everys: HashMap, + aggregator: AggFilterFn, +} + +impl Every { + pub(crate) fn new_distinct() -> Self { + Every { + everys: HashMap::new(), + aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()), + } + } + + pub(crate) fn new_all() -> Self { + Every { + everys: HashMap::new(), + aggregator: AggFilterFn::default(), + } + } +} + +impl AggregateFunction for Every { + fn next_value(&mut self, input_value: &Value, group: &Tuple) { + if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) { + match self.everys.get_mut(group) { + None => { + match input_value { + Boolean(_) => self.everys.insert(group.clone(), input_value.clone()), + _ => self.everys.insert(group.clone(), Missing), + }; + } + Some(acc) => { + *acc = match (acc.clone(), input_value) { + (Boolean(l), Value::Boolean(r)) => Value::Boolean(l && *r), + (_, _) => Missing, + }; + } + } + } + } + + fn compute(&self, group: &Tuple) -> Result { + Ok(self.everys.get(group).unwrap_or(&Null).clone()) + } } /// Represents an evaluation `GROUP BY` operator. For `GROUP BY` operational semantics, see section diff --git a/partiql-eval/src/eval/expr/coll.rs b/partiql-eval/src/eval/expr/coll.rs index 37c02a5e..89c2ab2e 100644 --- a/partiql-eval/src/eval/expr/coll.rs +++ b/partiql-eval/src/eval/expr/coll.rs @@ -4,9 +4,9 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr}; use itertools::{Itertools, Unique}; -use partiql_types::{ArrayType, BagType, PartiqlType, TypeKind, TYPE_MISSING}; +use partiql_types::{ArrayType, BagType, PartiqlType, TypeKind, TYPE_BOOL, TYPE_NUMERIC_TYPES}; use partiql_value::Value::{Missing, Null}; -use partiql_value::{Value, ValueIter}; +use partiql_value::{BinaryAnd, BinaryOr, Value, ValueIter}; use std::fmt::Debug; use std::hash::Hash; @@ -26,6 +26,10 @@ pub(crate) enum EvalCollFn { Min(SetQuantifier), /// Represents the `COLL_SUM` function, e.g. `COLL_SUM(DISTINCT [1, 2, 2, 3])`. Sum(SetQuantifier), + /// Represents the `COLL_ANY`/`COLL_SOME` function, e.g. `COLL_ANY(DISTINCT [true, true, false])`. + Any(SetQuantifier), + /// Represents the `COLL_EVERY` function, e.g. `COLL_EVERY(DISTINCT [true, true, false])`. + Every(SetQuantifier), } impl BindEvalExpr for EvalCollFn { @@ -34,26 +38,56 @@ impl BindEvalExpr for EvalCollFn { args: Vec>, ) -> Result, BindError> { fn create( + types: [PartiqlType; 1], args: Vec>, f: F, ) -> Result, BindError> where F: Fn(ValueIter) -> Value + 'static, { - let list = PartiqlType::new(TypeKind::Array(ArrayType::new_any())); - let bag = PartiqlType::new(TypeKind::Bag(BagType::new_any())); - let types = [PartiqlType::any_of([list, bag, TYPE_MISSING])]; UnaryValueExpr::create_typed::<{ STRICT }, _>(types, args, move |value| { value.sequence_iter().map(&f).unwrap_or(Missing) }) } + let boolean_elems = [PartiqlType::any_of([ + PartiqlType::new(TypeKind::Array(ArrayType::new(Box::new(TYPE_BOOL)))), + PartiqlType::new(TypeKind::Bag(BagType::new(Box::new(TYPE_BOOL)))), + ])]; + let numeric_elems = [PartiqlType::any_of([ + PartiqlType::new(TypeKind::Array(ArrayType::new(Box::new( + PartiqlType::any_of(TYPE_NUMERIC_TYPES), + )))), + PartiqlType::new(TypeKind::Bag(BagType::new(Box::new(PartiqlType::any_of( + TYPE_NUMERIC_TYPES, + ))))), + ])]; + let any_elems = [PartiqlType::any_of([ + PartiqlType::new(TypeKind::Array(ArrayType::new_any())), + PartiqlType::new(TypeKind::Bag(BagType::new_any())), + ])]; match *self { - EvalCollFn::Count(setq) => create::<{ STRICT }, _>(args, move |it| it.coll_count(setq)), - EvalCollFn::Avg(setq) => create::<{ STRICT }, _>(args, move |it| it.coll_avg(setq)), - EvalCollFn::Max(setq) => create::<{ STRICT }, _>(args, move |it| it.coll_max(setq)), - EvalCollFn::Min(setq) => create::<{ STRICT }, _>(args, move |it| it.coll_min(setq)), - EvalCollFn::Sum(setq) => create::<{ STRICT }, _>(args, move |it| it.coll_sum(setq)), + EvalCollFn::Count(setq) => { + create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_count(setq)) + } + EvalCollFn::Avg(setq) => { + create::<{ STRICT }, _>(numeric_elems, args, move |it| it.coll_avg(setq)) + } + EvalCollFn::Max(setq) => { + create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_max(setq)) + } + EvalCollFn::Min(setq) => { + create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_min(setq)) + } + EvalCollFn::Sum(setq) => { + create::<{ STRICT }, _>(numeric_elems, args, move |it| it.coll_sum(setq)) + } + EvalCollFn::Any(setq) => { + create::<{ STRICT }, _>(boolean_elems, args, move |it| it.coll_any(setq)) + } + EvalCollFn::Every(setq) => { + create::<{ STRICT }, _>(boolean_elems, args, move |it| it.coll_every(setq)) + } } } } @@ -181,6 +215,38 @@ trait CollIterator<'a>: Iterator { Null } } + + #[inline] + fn coll_any(self, setq: SetQuantifier) -> Value + where + Self: Sized, + { + self.filter(|e| e.is_present()) + .set_quantified(setq) + .coll_reduce_or(Null, |prev, x| { + if let Value::Boolean(_) = x { + ControlFlow::Continue(prev.or(x)) + } else { + ControlFlow::Break(Missing) + } + }) + } + + #[inline] + fn coll_every(self, setq: SetQuantifier) -> Value + where + Self: Sized, + { + self.filter(|e| e.is_present()) + .set_quantified(setq) + .coll_reduce_or(Null, |prev, x| { + if let Value::Boolean(_) = x { + ControlFlow::Continue(prev.and(x)) + } else { + ControlFlow::Break(Missing) + } + }) + } } /// [`Iterator`] helper methods for `COLL_*` operators for reducing values to a single value while diff --git a/partiql-eval/src/plan.rs b/partiql-eval/src/plan.rs index 50c2158a..eec67d63 100644 --- a/partiql-eval/src/plan.rs +++ b/partiql-eval/src/plan.rs @@ -13,9 +13,9 @@ use partiql_logical::{ use crate::error::{ErrorNode, PlanErr, PlanningError}; use crate::eval; use crate::eval::evaluable::{ - Avg, Count, EvalGroupingStrategy, EvalJoinKind, EvalOrderBy, EvalOrderBySortCondition, + Any, Avg, Count, EvalGroupingStrategy, EvalJoinKind, EvalOrderBy, EvalOrderBySortCondition, EvalOrderBySortSpec, EvalOuterExcept, EvalOuterIntersect, EvalOuterUnion, EvalSubQueryExpr, - Evaluable, Max, Min, Sum, + Evaluable, Every, Max, Min, Sum, }; use crate::eval::expr::{ BindError, BindEvalExpr, EvalBagExpr, EvalBetweenExpr, EvalCollFn, EvalDynamicLookup, EvalExpr, @@ -259,6 +259,12 @@ impl<'c> EvaluatorPlanner<'c> { (AggFunc::AggSum, logical::SetQuantifier::All) => { eval::evaluable::AggFunc::Sum(Sum::new_all()) } + (AggFunc::AggAny, logical::SetQuantifier::All) => { + eval::evaluable::AggFunc::Any(Any::new_all()) + } + (AggFunc::AggEvery, logical::SetQuantifier::All) => { + eval::evaluable::AggFunc::Every(Every::new_all()) + } (AggFunc::AggAvg, logical::SetQuantifier::Distinct) => { eval::evaluable::AggFunc::Avg(Avg::new_distinct()) } @@ -274,6 +280,12 @@ impl<'c> EvaluatorPlanner<'c> { (AggFunc::AggSum, logical::SetQuantifier::Distinct) => { eval::evaluable::AggFunc::Sum(Sum::new_distinct()) } + (AggFunc::AggAny, logical::SetQuantifier::Distinct) => { + eval::evaluable::AggFunc::Any(Any::new_distinct()) + } + (AggFunc::AggEvery, logical::SetQuantifier::Distinct) => { + eval::evaluable::AggFunc::Every(Every::new_distinct()) + } }; eval::evaluable::AggregateExpression { name: a_e.name.to_string(), @@ -715,6 +727,14 @@ impl<'c> EvaluatorPlanner<'c> { "coll_sum", EvalCollFn::Sum(setq.into()).bind::<{ STRICT }>(args), ), + CallName::CollAny(setq) => ( + "coll_any", + EvalCollFn::Any(setq.into()).bind::<{ STRICT }>(args), + ), + CallName::CollEvery(setq) => ( + "coll_every", + EvalCollFn::Every(setq.into()).bind::<{ STRICT }>(args), + ), CallName::ByName(name) => match self.catalog.get_function(name) { None => { self.errors.push(PlanningError::IllegalState(format!( diff --git a/partiql-logical-planner/src/builtins.rs b/partiql-logical-planner/src/builtins.rs index e9fe7289..86e83bcc 100644 --- a/partiql-logical-planner/src/builtins.rs +++ b/partiql-logical-planner/src/builtins.rs @@ -655,6 +655,76 @@ fn function_call_def_coll_sum() -> CallDef { } } +fn function_call_def_coll_any() -> CallDef { + CallDef { + names: vec!["coll_any", "coll_some"], + overloads: vec![ + CallSpec { + input: vec![CallSpecArg::Positional], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("all".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("distinct".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollAny(SetQuantifier::Distinct), + arguments: args, + }) + }), + }, + ], + } +} + +fn function_call_def_coll_every() -> CallDef { + CallDef { + names: vec!["coll_every"], + overloads: vec![ + CallSpec { + input: vec![CallSpecArg::Positional], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("all".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::All), + arguments: args, + }) + }), + }, + CallSpec { + input: vec![CallSpecArg::Named("distinct".into())], + output: Box::new(|args| { + logical::ValueExpr::Call(logical::CallExpr { + name: logical::CallName::CollEvery(SetQuantifier::Distinct), + arguments: args, + }) + }), + }, + ], + } +} + pub(crate) static FN_SYM_TAB: Lazy = Lazy::new(function_call_def); /// Function symbol table @@ -698,6 +768,8 @@ pub fn function_call_def() -> FnSymTab { function_call_def_coll_max(), function_call_def_coll_min(), function_call_def_coll_sum(), + function_call_def_coll_any(), + function_call_def_coll_every(), ] { assert!(!def.names.is_empty()); let primary = def.names[0]; diff --git a/partiql-logical-planner/src/lower.rs b/partiql-logical-planner/src/lower.rs index 7b110824..b1670d67 100644 --- a/partiql-logical-planner/src/lower.rs +++ b/partiql-logical-planner/src/lower.rs @@ -34,7 +34,7 @@ use partiql_ast_passes::error::{AstTransformError, AstTransformationError}; use partiql_catalog::Catalog; use partiql_extension_ion::decode::{IonDecoderBuilder, IonDecoderConfig}; use partiql_extension_ion::Encoding; -use partiql_logical::AggFunc::{AggAvg, AggCount, AggMax, AggMin, AggSum}; +use partiql_logical::AggFunc::{AggAny, AggAvg, AggCount, AggEvery, AggMax, AggMin, AggSum}; use std::sync::atomic::{AtomicU32, Ordering}; type FnvIndexMap = IndexMap; @@ -1172,6 +1172,18 @@ impl<'a, 'ast> Visitor<'ast> for AstToLogical<'a> { func: AggSum, setq, }, + "any" | "some" => AggregateExpression { + name: new_name, + expr: arg, + func: AggAny, + setq, + }, + "every" => AggregateExpression { + name: new_name, + expr: arg, + func: AggEvery, + setq, + }, _ => { // Include as an error but allow lowering to proceed for multiple error reporting self.errors diff --git a/partiql-logical/src/lib.rs b/partiql-logical/src/lib.rs index 30d8f86c..104f343d 100644 --- a/partiql-logical/src/lib.rs +++ b/partiql-logical/src/lib.rs @@ -359,6 +359,10 @@ pub enum AggFunc { AggMin, /// Represents SQL's `SUM` aggregation function AggSum, + /// Represents SQL's `ANY`/`SOME` aggregation function + AggAny, + /// Represents SQL's `EVERY` aggregation function + AggEvery, } /// Represents `GROUP BY` [, ] ... \[AS \] @@ -684,6 +688,8 @@ pub enum CallName { CollMax(SetQuantifier), CollMin(SetQuantifier), CollSum(SetQuantifier), + CollAny(SetQuantifier), + CollEvery(SetQuantifier), ByName(String), } diff --git a/partiql-parser/src/parse/parser_state.rs b/partiql-parser/src/parse/parser_state.rs index d114554a..833a2c2f 100644 --- a/partiql-parser/src/parse/parser_state.rs +++ b/partiql-parser/src/parse/parser_state.rs @@ -69,7 +69,8 @@ impl<'input> Default for ParserState<'input, NodeIdGenerator> { // TODO: currently needs to be manually kept in-sync with preprocessor's `built_in_aggs` // TODO: make extensible -const KNOWN_AGGREGATES: &str = "(?i:^count$)|(?i:^avg$)|(?i:^min$)|(?i:^max$)|(?i:^sum$)"; +const KNOWN_AGGREGATES: &str = + "(?i:^count$)|(?i:^avg$)|(?i:^min$)|(?i:^max$)|(?i:^sum$)|(?i:^any$)|(?i:^some$)|(?i:^every$)"; static KNOWN_AGGREGATE_PATTERN: Lazy = Lazy::new(|| Regex::new(KNOWN_AGGREGATES).unwrap()); impl<'input, I> ParserState<'input, I> diff --git a/partiql-parser/src/preprocessor.rs b/partiql-parser/src/preprocessor.rs index a76257bb..e23784f0 100644 --- a/partiql-parser/src/preprocessor.rs +++ b/partiql-parser/src/preprocessor.rs @@ -132,7 +132,7 @@ mod built_ins { pub(crate) fn built_in_aggs() -> FnExpr<'static> { FnExpr { // TODO: currently needs to be manually kept in-sync with parsers's `KNOWN_AGGREGATES` - fn_names: vec!["count", "avg", "min", "max", "sum"], + fn_names: vec!["count", "avg", "min", "max", "sum", "any", "some", "every"], #[rustfmt::skip] patterns: vec![ // e.g., count(all x) => count("all": x) diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index eb6c4a88..b67a8813 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -363,7 +363,7 @@ pub enum StructConstraint { #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)] #[allow(dead_code)] pub struct BagType { - element_type: Box, + pub element_type: Box, constraints: Vec, } @@ -385,7 +385,7 @@ impl BagType { #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)] #[allow(dead_code)] pub struct ArrayType { - element_type: Box, + pub element_type: Box, constraints: Vec, }