Skip to content

Commit

Permalink
Add SQL aggregates ANY, SOME, EVERY and their COLL_ versions
Browse files Browse the repository at this point in the history
  • Loading branch information
alancai98 committed Aug 4, 2023
1 parent 191b10e commit aa98fd9
Show file tree
Hide file tree
Showing 11 changed files with 324 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `partiql_ast_passes::static_typer` for type annotating the AST.
- Add ability to parse `ORDER BY`, `LIMIT`, `OFFSET` in children of set operators
- Add `OUTER` bag operator (`OUTER UNION`, `OUTER INTERSECT`, `OUTER EXCEPT`) implementation
- Implements the aggregation functions `ANY`, `SOME`, `EVERY` and their `COLL_` versions

### Fixes
- Fixes parsing of multiple consecutive path wildcards (e.g. `a[*][*][*]`), unpivot (e.g. `a.*.*.*`), and path expressions (e.g. `a[1 + 2][3 + 4][5 + 6]`)—previously these would not parse correctly.
Expand Down
18 changes: 16 additions & 2 deletions partiql-eval/src/eval/eval_expr_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use std::ops::ControlFlow;
#[inline]
pub(crate) fn subsumes(typ: &PartiqlType, value: &Value) -> bool {
match (typ.kind(), value) {
(_, Value::Null) => true,
(_, Value::Missing) => true,
(TypeKind::Any, _) => true,
(TypeKind::AnyOf(anyof), val) => anyof.types().any(|typ| subsumes(typ, val)),
(TypeKind::Null, Value::Null) => true,
Expand All @@ -36,10 +38,22 @@ pub(crate) fn subsumes(typ: &PartiqlType, value: &Value) -> bool {
Value::String(_),
) => true,
(TypeKind::Struct(_), Value::Tuple(_)) => true,
(TypeKind::Bag(_), Value::Bag(_)) => true,
(TypeKind::Bag(b_type), Value::Bag(b_values)) => {
let bag_element_type = b_type.element_type.as_ref();
let mut b_values = b_values.iter();
b_values.all(|b_value| {
println!("b_value: {:?}", b_value);
println!("bag_element_type: {:?}", bag_element_type);
subsumes(bag_element_type, b_value)
})
}
(TypeKind::DateTime, Value::DateTime(_)) => true,

(TypeKind::Array(_), Value::List(_)) => true,
(TypeKind::Array(a_type), Value::List(l_values)) => {
let array_element_type = a_type.element_type.as_ref();
let mut l_values = l_values.iter();
l_values.all(|l_value| subsumes(array_element_type, l_value))
}
_ => false,
}
}
Expand Down
147 changes: 113 additions & 34 deletions partiql-eval/src/eval/evaluable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ pub(crate) enum AggFunc {
Max(Max),
Min(Min),
Sum(Sum),
Any(Any),
Every(Every),
}

impl AggregateFunction for AggFunc {
Expand All @@ -361,6 +363,8 @@ impl AggregateFunction for AggFunc {
AggFunc::Max(v) => v.next_value(input_value, group),
AggFunc::Min(v) => v.next_value(input_value, group),
AggFunc::Sum(v) => v.next_value(input_value, group),
AggFunc::Any(v) => v.next_value(input_value, group),
AggFunc::Every(v) => v.next_value(input_value, group),
}
}

Expand All @@ -371,6 +375,8 @@ impl AggregateFunction for AggFunc {
AggFunc::Max(v) => v.compute(group),
AggFunc::Min(v) => v.compute(group),
AggFunc::Sum(v) => v.compute(group),
AggFunc::Any(v) => v.compute(group),
AggFunc::Every(v) => v.compute(group),
}
}
}
Expand Down Expand Up @@ -454,7 +460,7 @@ impl Avg {

impl AggregateFunction for Avg {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.avgs.get_mut(group) {
None => {
self.avgs.insert(group.clone(), (1, input_value.clone()));
Expand All @@ -468,12 +474,9 @@ impl AggregateFunction for Avg {
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
match self.avgs.get(group) {
None => Err(EvaluationError::IllegalState(
"Expect group to exist in avgs".to_string(),
)),
Some((0, _)) => Ok(Null),
Some((c, s)) => Ok(s / &Value::from(rust_decimal::Decimal::from(*c))),
match self.avgs.get(group).unwrap_or(&(0, Null)) {
(0, _) => Ok(Null),
(c, s) => Ok(s / &Value::from(rust_decimal::Decimal::from(*c))),
}
}
}
Expand Down Expand Up @@ -503,7 +506,7 @@ impl Count {

impl AggregateFunction for Count {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.counts.get_mut(group) {
None => {
self.counts.insert(group.clone(), 1);
Expand All @@ -516,12 +519,7 @@ impl AggregateFunction for Count {
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
match self.counts.get(group) {
None => Err(EvaluationError::IllegalState(
"Expect group to exist in counts".to_string(),
)),
Some(val) => Ok(Value::from(val)),
}
Ok(Value::from(self.counts.get(group).unwrap_or(&0)))
}
}

Expand Down Expand Up @@ -550,7 +548,7 @@ impl Max {

impl AggregateFunction for Max {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.maxes.get_mut(group) {
None => {
self.maxes.insert(group.clone(), input_value.clone());
Expand All @@ -563,12 +561,7 @@ impl AggregateFunction for Max {
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
match self.maxes.get(group) {
None => Err(EvaluationError::IllegalState(
"Expect group to exist in maxes".to_string(),
)),
Some(val) => Ok(val.clone()),
}
Ok(self.maxes.get(group).unwrap_or(&Null).clone())
}
}

Expand Down Expand Up @@ -597,7 +590,7 @@ impl Min {

impl AggregateFunction for Min {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.mins.get_mut(group) {
None => {
self.mins.insert(group.clone(), input_value.clone());
Expand All @@ -610,12 +603,7 @@ impl AggregateFunction for Min {
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
match self.mins.get(group) {
None => Err(EvaluationError::IllegalState(
"Expect group to exist in mins".to_string(),
)),
Some(val) => Ok(val.clone()),
}
Ok(self.mins.get(group).unwrap_or(&Null).clone())
}
}

Expand Down Expand Up @@ -644,7 +632,7 @@ impl Sum {

impl AggregateFunction for Sum {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if !input_value.is_absent() && self.aggregator.filter_value(input_value.clone(), group) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.sums.get_mut(group) {
None => {
self.sums.insert(group.clone(), input_value.clone());
Expand All @@ -657,13 +645,104 @@ impl AggregateFunction for Sum {
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
match self.sums.get(group) {
None => Err(EvaluationError::IllegalState(
"Expect group to exist in sums".to_string(),
)),
Some(val) => Ok(val.clone()),
Ok(self.sums.get(group).unwrap_or(&Null).clone())
}
}

/// Represents SQL's `ANY`/`SOME` aggregation function
#[derive(Debug)]
pub(crate) struct Any {
anys: HashMap<Tuple, Value>,
aggregator: AggFilterFn,
}

impl Any {
pub(crate) fn new_distinct() -> Self {
Any {
anys: HashMap::new(),
aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()),
}
}

pub(crate) fn new_all() -> Self {
Any {
anys: HashMap::new(),
aggregator: AggFilterFn::default(),
}
}
}

impl AggregateFunction for Any {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.anys.get_mut(group) {
None => {
match input_value {
Boolean(_) => self.anys.insert(group.clone(), input_value.clone()),
_ => self.anys.insert(group.clone(), Missing),
};
}
Some(acc) => {
*acc = match (acc.clone(), input_value) {
(Boolean(l), Value::Boolean(r)) => Value::Boolean(l || *r),
(_, _) => Missing,
};
}
}
}
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
Ok(self.anys.get(group).unwrap_or(&Null).clone())
}
}

/// Represents SQL's `EVERY` aggregation function
#[derive(Debug)]
pub(crate) struct Every {
everys: HashMap<Tuple, Value>,
aggregator: AggFilterFn,
}

impl Every {
pub(crate) fn new_distinct() -> Self {
Every {
everys: HashMap::new(),
aggregator: AggFilterFn::Distinct(AggFilterDistinct::new()),
}
}

pub(crate) fn new_all() -> Self {
Every {
everys: HashMap::new(),
aggregator: AggFilterFn::default(),
}
}
}

impl AggregateFunction for Every {
fn next_value(&mut self, input_value: &Value, group: &Tuple) {
if input_value.is_present() && self.aggregator.filter_value(input_value.clone(), group) {
match self.everys.get_mut(group) {
None => {
match input_value {
Boolean(_) => self.everys.insert(group.clone(), input_value.clone()),
_ => self.everys.insert(group.clone(), Missing),
};
}
Some(acc) => {
*acc = match (acc.clone(), input_value) {
(Boolean(l), Value::Boolean(r)) => Value::Boolean(l && *r),
(_, _) => Missing,
};
}
}
}
}

fn compute(&self, group: &Tuple) -> Result<Value, EvaluationError> {
Ok(self.everys.get(group).unwrap_or(&Null).clone())
}
}

/// Represents an evaluation `GROUP BY` operator. For `GROUP BY` operational semantics, see section
Expand Down
Loading

0 comments on commit aa98fd9

Please sign in to comment.