Skip to content

Commit

Permalink
Core: Fix incorrect searched CASE optimization (#14349)
Browse files Browse the repository at this point in the history
* Fix incorrect searched CASE optimization

There is an optimization for searched CASE where values are of boolean
type. It was converting the expression like

    CASE
        WHEN X THEN A
        WHEN Y THEN B
        ..
        [ ELSE D ]
    END

into

    (X AND A)
        OR (Y AND NOT X AND B)
        [ OR (NOT (X OR Y) AND D) ]

This had the following problems

- does not work for nullable conditions. If X is nullable, we cannot use
  NOT (X) to compliment it. We need to use `X IS DISTINCT FROM true`
- it does not work correctly when some conditions are nullable and other
  values are false. E.g. X=NULL, A=true, Y=NULL, B=true, D=false, the
  CASE should return false, but the boolean expression will simplify to
  `(NULL AND ..) OR (NULL AND ..) OR (false)` which is NULL, not false
  - thus we use `X` for truthness check of `X`, we need to test `X IS
    NOT DISTINCT FROM true`
- it did not work correctly when default D is missing, but conditions
  do not evaluate to NULL. CASE's result should be NULL but was false.

This commit fixes that optimization.

* Fix complexity comment
  • Loading branch information
findepi authored Jan 30, 2025
1 parent 07ee09a commit 11435de
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 29 deletions.
98 changes: 74 additions & 24 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1385,29 +1385,26 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& when_then_expr.len() < 3 // The rewrite is O(n²) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
// String disjunction of all the when predicates encountered so far. Not nullable.
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);

for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
.and(filter_expr.clone().not())
.and(*then);
let when = is_exactly_true(*when, info)?;
let case_expr =
when.clone().and(filter_expr.clone().not()).and(*then);

out_expr = out_expr.or(case_expr);
filter_expr = filter_expr.or(*when);
filter_expr = filter_expr.or(when);
}

if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
let else_expr = else_expr.map(|b| *b).unwrap_or_else(lit_bool_null);
let case_expr = filter_expr.not().and(else_expr);
out_expr = out_expr.or(case_expr);

// Do a first pass at simplification
out_expr.rewrite(self)?
Expand Down Expand Up @@ -1881,6 +1878,19 @@ fn inlist_except(mut l1: InList, l2: &InList) -> Result<Expr> {
Ok(Expr::InList(l1))
}

/// Returns expression testing a boolean `expr` for being exactly `true` (not `false` or NULL).
fn is_exactly_true(expr: Expr, info: &impl SimplifyInfo) -> Result<Expr> {
if !info.nullable(&expr)? {
Ok(expr)
} else {
Ok(Expr::BinaryExpr(BinaryExpr {
left: Box::new(expr),
op: Operator::IsNotDistinctFrom,
right: Box::new(lit(true)),
}))
}
}

#[cfg(test)]
mod tests {
use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -3272,12 +3282,12 @@ mod tests {
simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("not_ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
))),
col("c2").not().and(col("c2")) // #1716
lit(false) // #1716
);

// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
Expand All @@ -3292,12 +3302,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![(
Box::new(col("c2").not_eq(lit(false))),
Box::new(col("c2_non_null").not_eq(lit(false))),
Box::new(lit("ok").eq(lit("ok"))),
)],
Some(Box::new(col("c2").eq(lit(true)))),
Some(Box::new(col("c2_non_null").eq(lit(true)))),
)))),
col("c2")
col("c2_non_null")
);

// CASE WHEN ISNULL(c2) THEN true ELSE c2
Expand Down Expand Up @@ -3328,12 +3338,12 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c1 then true WHEN c2 then true ELSE false
Expand All @@ -3347,13 +3357,53 @@ mod tests {
simplify(simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("c1")), Box::new(lit(true)),),
(Box::new(col("c2")), Box::new(lit(false)),),
(Box::new(col("c1_non_null")), Box::new(lit(true)),),
(Box::new(col("c2_non_null")), Box::new(lit(false)),),
],
Some(Box::new(lit(true))),
)))),
col("c1").or(col("c1").not().and(col("c2").not()))
col("c1_non_null").or(col("c1_non_null").not().and(col("c2_non_null").not()))
);

// CASE WHEN c > 0 THEN true END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
None,
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true)).or(distinct_from(
col("c3").gt(lit(0_i64)),
lit(true)
)
.and(lit_bool_null()))
);

// CASE WHEN c > 0 THEN true ELSE false END AS c1
assert_eq!(
simplify(simplify(Expr::Case(Case::new(
None,
vec![(Box::new(col("c3").gt(lit(0_i64))), Box::new(lit(true)))],
Some(Box::new(lit(false))),
)))),
not_distinct_from(col("c3").gt(lit(0_i64)), lit(true))
);
}

fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsDistinctFrom,
right: Box::new(right.into()),
})
}

fn not_distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
Expr::BinaryExpr(BinaryExpr {
left: Box::new(left.into()),
op: Operator::IsNotDistinctFrom,
right: Box::new(right.into()),
})
}

#[test]
Expand Down
20 changes: 15 additions & 5 deletions datafusion/sqllogictest/test_files/case.slt
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,22 @@ query B
select case when a=1 then false end from foo;
----
false
false
false
false
false
false
NULL
NULL
NULL
NULL
NULL

query IBB
SELECT c,
CASE WHEN c > 0 THEN true END AS c1,
CASE WHEN c > 0 THEN true ELSE false END AS c2
FROM (VALUES (1), (0), (-1), (NULL)) AS t(c)
----
1 true true
0 NULL false
-1 NULL false
NULL NULL false

statement ok
drop table foo

0 comments on commit 11435de

Please sign in to comment.