diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index c3d7ac749..0bb5ce4f0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -3452,7 +3452,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 /**
  * Trait for providing serialization logic for expressions.
  */
-trait CometExpressionSerde {
+trait CometExpressionSerde extends CometExprShim {
 
   /**
    * Convert a Spark expression into a protocol buffer representation that can be passed into
@@ -3473,4 +3473,37 @@ trait CometExpressionSerde {
       expr: Expression,
       inputs: Seq[Attribute],
       binding: Boolean): Option[ExprOuterClass.Expr]
+
+  def isPrimitiveType(dt: DataType): Boolean = {
+    import DataTypes._
+    dt match {
+      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
+        true
+      case _ => false
+    }
+  }
+
+  def isDecimalType(dt: DataType): Boolean = {
+    dt match {
+      case _: DecimalType => true
+      case _ => false
+    }
+  }
+
+  def isTemporalType(dt: DataType): Boolean = {
+    import DataTypes._
+    dt match {
+      case DateType | TimestampType => true
+      case t if isTimestampNTZType(t) => true
+      case _ => false
+    }
+  }
+
+  def isStringOrBinaryType(dt: DataType): Boolean = {
+    import DataTypes._
+    dt match {
+      case StringType | BinaryType => true
+      case _ => false
+    }
+  }
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index 9058a641e..b45d7ca35 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -20,26 +20,57 @@ package org.apache.comet.serde
 
 import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Attribute, Expression}
-import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
 import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr
-import org.apache.comet.shims.CometExprShim
 
-object CometArrayRemove extends CometExpressionSerde with CometExprShim {
+object CometArrayContains extends CometExpressionSerde {
 
   /** Exposed for unit testing */
   def isTypeSupported(dt: DataType): Boolean = {
-    import DataTypes._
+    if (isPrimitiveType(dt) || isDecimalType(dt) || isTemporalType(dt) || isStringOrBinaryType(
+        dt)) {
+      return true
+    }
+    dt match {
+      case ArrayType(elementType, _) => isTypeSupported(elementType)
+      case _ => false
+    }
+  }
+
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    // check that both the array argument and the element argument have supported types
+    val inputTypes: Set[DataType] = expr.children.map(_.dataType).toSet
+    for (dt <- inputTypes) {
+      if (!isTypeSupported(dt)) {
+        withInfo(expr, s"data type not supported: $dt")
+        return None
+      }
+    }
+    createBinaryExpr(
+      expr,
+      expr.children(0),
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
+  }
+}
+
+object CometArrayRemove extends CometExpressionSerde {
+
+  /** Exposed for unit testing */
+  def isTypeSupported(dt: DataType): Boolean = {
+    if (isPrimitiveType(dt) || isDecimalType(dt) || isTemporalType(dt) || isStringOrBinaryType(
+        dt)) {
+      return true
+    }
     dt match {
-      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
-          _: DecimalType | DateType | TimestampType | StringType | BinaryType =>
-        true
-      case t if isTimestampNTZType(t) => true
       case ArrayType(elementType, _) => isTypeSupported(elementType)
-      case _: StructType =>
-        // https://github.com/apache/datafusion-comet/issues/1307
-        false
      case _ => false
     }
   }
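The `isTypeSupported` checks above compose the new `CometExpressionSerde` helper predicates and recurse into array element types. Below is a minimal standalone sketch of that recursion, not part of the patch: it inlines the predicates against Spark's public `org.apache.spark.sql.types` API and omits the version-dependent `isTimestampNTZType` shim.

```scala
import org.apache.spark.sql.types._

object TypeSupportSketch {
  // same shape as CometArrayContains.isTypeSupported in the diff above
  def isTypeSupported(dt: DataType): Boolean = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
        DateType | TimestampType | StringType | BinaryType =>
      true
    case _: DecimalType => true
    // recurse: an array is supported only if its element type is supported
    case ArrayType(elementType, _) => isTypeSupported(elementType)
    case _ => false
  }

  def main(args: Array[String]): Unit = {
    // nested arrays recurse down to the leaf type
    println(isTypeSupported(ArrayType(ArrayType(IntegerType)))) // true
    // maps are not handled, so arrays of maps fall through to the catch-all
    println(isTypeSupported(ArrayType(MapType(StringType, IntegerType)))) // false
  }
}
```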
diff --git a/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala b/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala
index f209cc4c9..ee2773579 100644
--- a/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala
+++ b/spark/src/main/scala/org/apache/comet/testing/ParquetGenerator.scala
@@ -212,8 +212,8 @@ object ParquetGenerator {
 }
 
 case class DataGenOptions(
-    allowNull: Boolean,
-    generateNegativeZero: Boolean,
-    generateArray: Boolean,
-    generateStruct: Boolean,
-    generateMap: Boolean)
+    allowNull: Boolean = true,
+    generateNegativeZero: Boolean = true,
+    generateArray: Boolean = true,
+    generateStruct: Boolean = true,
+    generateMap: Boolean = true)
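Defaulting every `DataGenOptions` field to `true` lets call sites name only the flags they override. A before/after sketch (the case class is copied from the diff; the vals are illustrative):

```scala
case class DataGenOptions(
    allowNull: Boolean = true,
    generateNegativeZero: Boolean = true,
    generateArray: Boolean = true,
    generateStruct: Boolean = true,
    generateMap: Boolean = true)

// previously, every flag had to be passed explicitly at each call site:
val explicit = DataGenOptions(
  allowNull = true,
  generateNegativeZero = true,
  generateArray = false,
  generateStruct = false,
  generateMap = false)

// with defaults, only the overrides are named, as in the updated test suite:
val concise = DataGenOptions(generateArray = false, generateStruct = false, generateMap = false)

// case-class equality confirms the two are equivalent
assert(explicit == concise)
```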
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 5727f9f90..609fff9fb 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -19,18 +19,55 @@
 
 package org.apache.comet
 
+import java.io.File
+
 import scala.collection.immutable.HashSet
 import scala.util.Random
 
 import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.{CometTestBase, DataFrame}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.types.StructType
 
+import org.apache.comet.serde.{CometArrayContains, CometArrayRemove}
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}
 
 class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
+  // TODO enable complex types once native scan supports them
+  private val dataGenOptions =
+    DataGenOptions(generateArray = false, generateStruct = false, generateMap = false)
+
+  test("array_contains") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000)
+      spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+      checkSparkAnswerAndOperator(
+        spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
+      checkSparkAnswerAndOperator(
+        spark.sql("SELECT array_contains((CASE WHEN _2 = _3 THEN array(_4) END), _4) FROM t1"))
+    }
+  }
+
+  test("array_contains - test all types") {
+    withTempDir { dir =>
+      val df = generateTestData(dir, dataGenOptions)
+      df.createOrReplaceTempView("t1")
+      // test with array of each column
+      for (field <- df.schema.fields) {
+        val fieldName = field.name
+        sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
+          .createOrReplaceTempView("t2")
+        val df = sql("SELECT array_contains(a, b) FROM t2")
+        if (CometArrayContains.isTypeSupported(field.dataType)) {
+          checkSparkAnswerAndOperator(df)
+        } else {
+          checkSparkAnswer(df)
+        }
+      }
+    }
+  }
+
   test("array_remove - integer") {
     Seq(true, false).foreach { dictionaryEnabled =>
       withTempDir { dir =>
@@ -47,68 +84,20 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
-  test("array_remove - test all types (native Parquet reader)") {
+  test("array_remove - test all types") {
     withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          DataGenOptions(
-            allowNull = true,
-            generateNegativeZero = true,
-            generateArray = false,
-            generateStruct = false,
-            generateMap = false))
-      }
-      val table = spark.read.parquet(filename)
-      table.createOrReplaceTempView("t1")
+      val df = generateTestData(dir, dataGenOptions)
+      df.createOrReplaceTempView("t1")
       // test with array of each column
-      for (fieldName <- table.schema.fieldNames) {
+      for (field <- df.schema.fields) {
+        val fieldName = field.name
         sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
           .createOrReplaceTempView("t2")
         val df = sql("SELECT array_remove(a, b) FROM t2")
-        checkSparkAnswerAndOperator(df)
-      }
-    }
-  }
-
-  test("array_remove - test all types (convert from Parquet)") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        val options = DataGenOptions(
-          allowNull = true,
-          generateNegativeZero = true,
-          generateArray = true,
-          generateStruct = true,
-          generateMap = false)
-        ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
-      }
-      withSQLConf(
-        CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
-        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
-        CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
-        val table = spark.read.parquet(filename)
-        table.createOrReplaceTempView("t1")
-        // test with array of each column
-        for (field <- table.schema.fields) {
-          val fieldName = field.name
-          sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
-            .createOrReplaceTempView("t2")
-          val df = sql("SELECT array_remove(a, b) FROM t2")
-          field.dataType match {
-            case _: StructType =>
-              // skip due to https://github.com/apache/datafusion-comet/issues/1314
-            case _ =>
-              checkSparkAnswer(df)
-          }
+        if (CometArrayRemove.isTypeSupported(field.dataType)) {
+          checkSparkAnswerAndOperator(df)
+        } else {
+          checkSparkAnswer(df)
         }
       }
     }
   }
@@ -131,4 +120,14 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
         checkExplainString = false)
     }
   }
+
+  private def generateTestData(dir: File, options: DataGenOptions): DataFrame = {
+    val path = new Path(dir.toURI.toString, "test.parquet")
+    val filename = path.toString
+    val random = new Random(42)
+    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+      ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
+    }
+    spark.read.parquet(filename)
+  }
 }
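Both "test all types" tests now follow the same per-field pattern: wrap each column in an array, require fully native execution only when the serde object reports the type as supported, and fall back to a result-only comparison otherwise. A possible consolidation, sketched here with a hypothetical helper name (`checkSparkAnswerAndOperator` and `checkSparkAnswer` are the existing `CometTestBase` methods already used above; `org.apache.spark.sql.types.DataType` is assumed imported):

```scala
// hypothetical refactor, not part of the patch
private def checkAllFieldTypes(
    table: DataFrame,
    function: String,
    isSupported: DataType => Boolean): Unit = {
  for (field <- table.schema.fields) {
    val name = field.name
    sql(s"SELECT array($name, $name) as a, $name as b FROM t1").createOrReplaceTempView("t2")
    val result = sql(s"SELECT $function(a, b) FROM t2")
    if (isSupported(field.dataType)) {
      // supported types must run entirely in Comet
      checkSparkAnswerAndOperator(result)
    } else {
      // unsupported types fall back to Spark but must still return the same answer
      checkSparkAnswer(result)
    }
  }
}

// usage, mirroring the two tests:
//   checkAllFieldTypes(df, "array_contains", CometArrayContains.isTypeSupported)
//   checkAllFieldTypes(df, "array_remove", CometArrayRemove.isTypeSupported)
```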
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1")); - } - } - test("array_intersect") { Seq(true, false).foreach { dictionaryEnabled => withTempDir { dir =>