diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 95926bfee..bef4d7c03 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -110,6 +110,7 @@ use datafusion_expr::{ WindowFunctionDefinition, }; use datafusion_functions_nested::array_has::ArrayHas; +use datafusion_functions_nested::repeat::array_repeat_udf; use datafusion_physical_expr::expressions::{Literal, StatsType}; use datafusion_physical_expr::window::WindowExpr; use datafusion_physical_expr::LexOrdering; @@ -776,6 +777,49 @@ impl PhysicalPlanner { Ok(Arc::new(case_expr)) } + ExprStruct::ArrayRepeat(expr) => { + let src_expr = + self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; + let count_expr = + self.create_expr(expr.right.as_ref().unwrap(), Arc::clone(&input_schema))?; + // Cast count_expr from Int32 to Int64 to support df count argument + let count_expr: Arc<dyn PhysicalExpr> = + match count_expr.data_type(&Arc::clone(&input_schema))? 
{ + DataType::Int32 => Arc::new(CastExpr::new( + count_expr, + DataType::Int64, + Some(CastOptions::default()), + )), + _ => count_expr, + }; + + let args = vec![Arc::clone(&src_expr), Arc::clone(&count_expr)]; + + let datafusion_array_repeat = array_repeat_udf(); + let data_types: Vec<DataType> = vec![ + src_expr.data_type(&Arc::clone(&input_schema))?, + count_expr.data_type(&Arc::clone(&input_schema))?, + ]; + let return_type = datafusion_array_repeat.return_type(&data_types)?; + + let array_repeat_expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new( + "array_repeat", + datafusion_array_repeat, + args, + return_type, + )); + + let is_null_expr: Arc<dyn PhysicalExpr> = Arc::new(IsNullExpr::new(count_expr)); + let null_literal_expr: Arc<dyn PhysicalExpr> = + Arc::new(Literal::new(ScalarValue::Null)); + + let case_expr = CaseExpr::try_new( + None, + vec![(is_null_expr, null_literal_expr)], + Some(array_repeat_expr), + )?; + Ok(Arc::new(case_expr)) + } ExprStruct::ArrayIntersect(expr) => { + let left_expr = self.create_expr(expr.left.as_ref().unwrap(), Arc::clone(&input_schema))?; diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 83d6da7cb..25d22c5b9 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -88,6 +88,7 @@ message Expr { BinaryExpr array_remove = 61; BinaryExpr array_intersect = 62; ArrayJoin array_join = 63; + BinaryExpr array_repeat = 64; } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c3d7ac749..4707574b4 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2396,6 +2396,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim inputs, binding, (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) + case _ if expr.prettyName == "array_repeat" => + createBinaryExpr( + expr, + expr.children(0), 
+ expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayRepeat(binaryExpr)) case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) => val arrayExprProto = exprToProto(arrayExpr, inputs, binding) val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 99cf4bad4..f7fe25de3 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2701,4 +2701,22 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } + + test("array_repeat") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled, 10000) + spark.read.parquet(path.toString).createOrReplaceTempView("t1") + + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_2, _4) from t1")) + checkSparkAnswerAndOperator( + sql("SELECT array_repeat(_2, 5) from t1 where _2 is not null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, 2) from t1 where _3 is null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(_3, _3) from t1 where _3 is null")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(cast(_3 as string), 2) from t1")) + checkSparkAnswerAndOperator(sql("SELECT array_repeat(array(_2, _3, _4), 2) from t1")) + } + } + } }