From 30b05b390d3aa734e5124feb338acd829a9a4941 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Thu, 23 Jan 2025 14:14:55 +0100 Subject: [PATCH] fix failing tests for variantagg --- .../sql/catalyst/analysis/ResolveDefaultStringTypes.scala | 5 +++-- .../catalyst/expressions/variant/variantExpressions.scala | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala index 70b2193aa2432..1f3818c1f383c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultStringTypes.scala @@ -156,10 +156,11 @@ object ResolveDefaultStringTypes extends Rule[LogicalPlan] { case expression if needsCast(expression) => expression.setTagValue(CAST_ADDED_TAG, ()) newType => { - if (expression.dataType.sameType(newType)) { + val replacedType = replaceDefaultStringType(expression.dataType, newType) + if (newType == StringType || replacedType.sameType(newType)) { expression } else { - Cast(expression, replaceDefaultStringType(expression.dataType, newType)) + Cast(expression, replacedType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index f722329097bc0..528dbed02ff46 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -858,12 +858,13 @@ case class SchemaOfVariantAgg( extends TypedImperativeAggregate[DataType] with ExpectsInputTypes with QueryErrorsBase - with DefaultStringProducingExpression with UnaryLike[Expression] { def this(child: Expression) = this(child, 0, 0) override def inputTypes: Seq[AbstractDataType] = Seq(VariantType) + override def dataType: DataType = SQLConf.get.defaultStringType + override def nullable: Boolean = false override def createAggregationBuffer(): DataType = NullType