Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Support arrays_overlap function #1312

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ use datafusion::{
prelude::SessionContext,
};
use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr};
use datafusion_functions_nested::array_has::array_has_any_udf;
use datafusion_functions_nested::concat::ArrayAppend;
use datafusion_functions_nested::remove::array_remove_all_udf;
use datafusion_functions_nested::set_ops::array_intersect_udf;
Expand Down Expand Up @@ -818,6 +819,21 @@ impl PhysicalPlanner {
));
Ok(array_join_expr)
}
ExprStruct::ArraysOverlap(expr) => {
let left_array_expr =
self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?;
let right_array_expr =
self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?;
let args = vec![Arc::clone(&left_array_expr), right_array_expr];
let datafusion_array_has_any = array_has_any_udf();
let array_has_any_expr = Arc::new(ScalarFunctionExpr::new(
"array_has_any",
datafusion_array_has_any,
args,
DataType::Boolean,
));
Ok(array_has_any_expr)
}
expr => Err(ExecutionError::GeneralError(format!(
"Not implemented: {:?}",
expr
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ message Expr {
BinaryExpr array_remove = 61;
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
}
}

Expand Down
16 changes: 16 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,22 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
None
}
case ArraysOverlap(leftArrayExpr, rightArrayExpr) =>
if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
createBinaryExpr(
expr,
leftArrayExpr,
rightArrayExpr,
inputs,
binding,
(builder, binaryExpr) => builder.setArraysOverlap(binaryExpr))
} else {
withInfo(
expr,
s"$expr is not supported yet. To enable all incompatible casts, set " +
s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
None
}
case _ =>
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
None
Expand Down
21 changes: 21 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2701,4 +2701,25 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}
}

test("arrays_overlap") {
withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled, 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1")
checkSparkAnswerAndOperator(sql(
"SELECT arrays_overlap(array(_2, _3, _4), array(_3, _4)) from t1 where _2 is not null"))
checkSparkAnswerAndOperator(sql(
"SELECT arrays_overlap(array('a', null, cast(_1 as string)), array('b', cast(_1 as string), cast(_2 as string))) from t1 where _1 is not null"))
checkSparkAnswerAndOperator(sql(
"SELECT arrays_overlap(array('a', null), array('b', null)) from t1 where _1 is not null"))
checkSparkAnswerAndOperator(spark.sql(
"SELECT arrays_overlap((CASE WHEN _2 =_3 THEN array(_6, _7) END), array(_6, _7)) FROM t1"));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think _2 = _3 may always be true. The problem with makeParquetFileAllTypes is that every for each row, each column contains the same integer value cast to the column's type, so it is not ideal for tests like this. We can improve as part of #1269

}
}
}
}

}
Loading