chore: improve tests for array expressions #1339

Closed
wants to merge 3 commits into from
35 changes: 34 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -3452,7 +3452,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
/**
* Trait for providing serialization logic for expressions.
*/
trait CometExpressionSerde {
trait CometExpressionSerde extends CometExprShim {

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
@@ -3473,4 +3473,37 @@ trait CometExpressionSerde {
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr]

def isPrimitiveType(dt: DataType): Boolean = {
import DataTypes._
dt match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
true
case _ => false
}
}

def isDecimalType(dt: DataType): Boolean = {
dt match {
case _: DecimalType => true
case _ => false
}
}

def isTemporalType(dt: DataType): Boolean = {
import DataTypes._
dt match {
case DateType | TimestampType => true
case t if isTimestampNTZType(t) => true
case _ => false
}
}

def isStringOrBinaryType(dt: DataType): Boolean = {
import DataTypes._
dt match {
case StringType | BinaryType => true
case _ => false
}
}
}
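
The type-check helpers added to CometExpressionSerde above are shared by the serde objects in arrays.scala below. As a rough illustration of the intended usage, a serde implementation can now compose the helpers instead of repeating long type pattern matches. The object name and its convert body here are hypothetical and not part of this PR; it is assumed to live in org.apache.comet.serde alongside the trait.

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.types.{ArrayType, DataType}

import org.apache.comet.CometSparkSessionExtensions.withInfo

// Hypothetical serde object, for illustration only.
object CometIllustrativeSerde extends CometExpressionSerde {

  // Supported: primitives, decimals, temporals, strings/binary, and arrays of those.
  def isTypeSupported(dt: DataType): Boolean = dt match {
    case ArrayType(elementType, _) => isTypeSupported(elementType)
    case _ =>
      isPrimitiveType(dt) || isDecimalType(dt) || isTemporalType(dt) || isStringOrBinaryType(dt)
  }

  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    expr.children.map(_.dataType).find(dt => !isTypeSupported(dt)) match {
      case Some(dt) =>
        // Record why the expression falls back to Spark, as the real serde objects do.
        withInfo(expr, s"data type not supported: $dt")
        None
      case None =>
        None // a real serde would build and return the protobuf expression here
    }
  }
}
```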
53 changes: 42 additions & 11 deletions spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -20,26 +20,57 @@
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Attribute, Expression}
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr
import org.apache.comet.shims.CometExprShim

object CometArrayRemove extends CometExpressionSerde with CometExprShim {
object CometArrayContains extends CometExpressionSerde {

/** Exposed for unit testing */
def isTypeSupported(dt: DataType): Boolean = {
import DataTypes._
if (isPrimitiveType(dt) || isDecimalType(dt) || isTemporalType(dt) || isStringOrBinaryType(
dt)) {
return true
}
dt match {
case ArrayType(elementType, _) => isTypeSupported(elementType)
case _ => false
}
}

override def convert(
expr: Expression,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
val ar = expr.asInstanceOf[ArrayRemove]
val inputTypes: Set[DataType] = ar.children.map(_.dataType).toSet
for (dt <- inputTypes) {
if (!isTypeSupported(dt)) {
withInfo(expr, s"data type not supported: $dt")
return None
}
}
createBinaryExpr(
expr,
expr.children(0),
expr.children(1),
inputs,
binding,
(builder, binaryExpr) => builder.setArrayContains(binaryExpr))
}
}

object CometArrayRemove extends CometExpressionSerde {

/** Exposed for unit testing */
def isTypeSupported(dt: DataType): Boolean = {
if (isPrimitiveType(dt) || isDecimalType(dt) || isTemporalType(dt) || isStringOrBinaryType(
dt)) {
return true
}
dt match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
_: DecimalType | DateType | TimestampType | StringType | BinaryType =>
true
case t if isTimestampNTZType(t) => true
case ArrayType(elementType, _) => isTypeSupported(elementType)
case _: StructType =>
// https://github.com/apache/datafusion-comet/issues/1307
false
case _ => false
}
}
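For reference, the recursive isTypeSupported checks in this file accept nested array element types and reject anything that falls through to the default case. A small illustrative snippet (not part of the diff, assuming CometArrayContains is in scope) of the values the new check should return:

```scala
import org.apache.spark.sql.types._

CometArrayContains.isTypeSupported(IntegerType)                       // true: primitive type
CometArrayContains.isTypeSupported(DecimalType(10, 2))                // true: decimal type
CometArrayContains.isTypeSupported(ArrayType(ArrayType(StringType)))  // true: recurses into element types
CometArrayContains.isTypeSupported(MapType(StringType, IntegerType))  // false: falls through to the default case
```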
@@ -212,8 +212,8 @@ object ParquetGenerator {
}

case class DataGenOptions(
allowNull: Boolean,
generateNegativeZero: Boolean,
generateArray: Boolean,
generateStruct: Boolean,
generateMap: Boolean)
allowNull: Boolean = true,
generateNegativeZero: Boolean = true,
generateArray: Boolean = true,
generateStruct: Boolean = true,
generateMap: Boolean = true)
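
Giving every DataGenOptions flag a default of true means test call sites only have to name the flags they want to disable. Roughly (illustrative, matching the usage in the test suite below):

```scala
// Only the flags being turned off need to be named; the rest default to true.
val simpleTypesOnly = DataGenOptions(generateArray = false, generateStruct = false, generateMap = false)

// All defaults: nulls, negative zero, arrays, structs, and maps are all generated.
val allTypes = DataGenOptions()
```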
117 changes: 58 additions & 59 deletions spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -19,18 +19,55 @@

package org.apache.comet

import java.io.File

import scala.collection.immutable.HashSet
import scala.util.Random

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.{CometTestBase, DataFrame}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.types.StructType

import org.apache.comet.serde.{CometArrayContains, CometArrayRemove}
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {

// TODO enable complex types once native scan supports them
private val dataGenOptions =
DataGenOptions(generateArray = false, generateStruct = false, generateMap = false)

test("array_contains") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}

test("array_contains - test all types") {
withTempDir { dir =>
val df = generateTestData(dir, dataGenOptions)
df.createOrReplaceTempView("t1")
// test with array of each column
for (field <- df.schema.fields) {
val fieldName = field.name
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_contains(a, b) FROM t2")
if (CometArrayContains.isTypeSupported(field.dataType)) {
checkSparkAnswerAndOperator(df)
} else {
checkSparkAnswer(df)
}
}
}
}

test("array_remove - integer") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
@@ -47,68 +84,20 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
}
}

test("array_remove - test all types (native Parquet reader)") {
test("array_remove - test all types") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
ParquetGenerator.makeParquetFile(
random,
spark,
filename,
100,
DataGenOptions(
allowNull = true,
generateNegativeZero = true,
generateArray = false,
generateStruct = false,
generateMap = false))
}
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
val df = generateTestData(dir, dataGenOptions)
df.createOrReplaceTempView("t1")
// test with array of each column
for (fieldName <- table.schema.fieldNames) {
for (field <- df.schema.fields) {
val fieldName = field.name
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_remove(a, b) FROM t2")
checkSparkAnswerAndOperator(df)
}
}
}

test("array_remove - test all types (convert from Parquet)") {
Member Author commented:

This turned out not to be very useful, since we don't support arrays when using COMET_CONVERT_FROM_PARQUET_ENABLED

withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
val options = DataGenOptions(
allowNull = true,
generateNegativeZero = true,
generateArray = true,
generateStruct = true,
generateMap = false)
ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
}
withSQLConf(
CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false",
CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.key -> "true") {
val table = spark.read.parquet(filename)
table.createOrReplaceTempView("t1")
// test with array of each column
for (field <- table.schema.fields) {
val fieldName = field.name
sql(s"SELECT array($fieldName, $fieldName) as a, $fieldName as b FROM t1")
.createOrReplaceTempView("t2")
val df = sql("SELECT array_remove(a, b) FROM t2")
field.dataType match {
case _: StructType =>
// skip due to https://github.com/apache/datafusion-comet/issues/1314
case _ =>
checkSparkAnswer(df)
}
if (CometArrayRemove.isTypeSupported(field.dataType)) {
checkSparkAnswerAndOperator(df)
} else {
checkSparkAnswer(df)
}
}
}
@@ -131,4 +120,14 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper
checkExplainString = false)
}
}

private def generateTestData(dir: File, options: DataGenOptions): DataFrame = {
val path = new Path(dir.toURI.toString, "test.parquet")
val filename = path.toString
val random = new Random(42)
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
ParquetGenerator.makeParquetFile(random, spark, filename, 100, options)
}
spark.read.parquet(filename)
}
}
12 changes: 0 additions & 12 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -2657,18 +2657,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("array_contains") {
withTempDir { dir =>
val path = new Path(dir.toURI.toString, "test.parquet")
makeParquetFileAllTypes(path, dictionaryEnabled = false, n = 10000)
spark.read.parquet(path.toString).createOrReplaceTempView("t1");
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains(array(_2, _3, _4), _2) FROM t1"))
checkSparkAnswerAndOperator(
spark.sql("SELECT array_contains((CASE WHEN _2 =_3 THEN array(_4) END), _4) FROM t1"));
}
}

test("array_intersect") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>