diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 0c70ef63ff285..aa95478778268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType @@ -42,6 +42,23 @@ case class Mode( this(child, 0, 0, Some(reverse)) } + override def checkInputDataTypes(): TypeCheckResult = { + if (UnsafeRowUtils.isBinaryStable(child.dataType) || + !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] && + !UnsafeRowUtils.isBinaryStable(f))) { + /* + * The Mode class uses collation awareness logic to handle string data. + * Complex types with collated fields are not yet supported. + */ + // TODO: SPARK-48700: Mode expression for complex types (all collations) + super.checkInputDataTypes() + } else { + TypeCheckResult.TypeCheckFailure("The input to the function 'mode' includes" + + " a map with keys and/or values which are not binary-stable. This is not yet" + + s"supported by ${prettyName}.") + } + } + // Returns null for empty inputs override def nullable: Boolean = true @@ -71,7 +88,6 @@ case class Mode( buffer } - private def getCollationAwareBuffer( childDataType: DataType, buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = { @@ -84,10 +100,13 @@ case class Mode( private def getMyBuffer( childDataType: DataType): Option[AnyRef => _] = { + println(s"get my buffer. ${childDataType.getClass.getCanonicalName}") childDataType match { // Short-circuit if there is no collation. case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None - case c: StringType => Some(k => + case c: StringType => + println("Get Collation Key") + Some(k => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], c.collationId)) case at: ArrayType => Some(k => recursivelyGetBufferForArrayType(at, k.asInstanceOf[ArrayData])) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 96bfcd5dd07f9..e20f97c5918b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,14 +19,12 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat - import scala.collection.immutable.Seq - import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.Mode -import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} +import org.apache.spark.sql.internal.{SQLConf, SqlApiConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1758,6 +1756,38 @@ class CollationSQLExpressionsSuite }) } + test("Support Mode.eval(buffer) with complex map types") { + case class UTF8StringModeTestCase[R]( + collationId: String, + bufferValues: Map[InternalRow, Long], + result: R) + + val bufferValuesUTF8String: Map[Any, Long] = Map( + UTF8String.fromString("a") -> 5L, + UTF8String.fromString("b") -> 4L, + UTF8String.fromString("B") -> 3L, + UTF8String.fromString("d") -> 2L, + UTF8String.fromString("e") -> 1L) + + val bufferValuesComplex = bufferValuesUTF8String.map{ + case (k, v) => (InternalRow.apply(org.apache.spark.sql.catalyst.util.MapData.apply(Map(k -> 1))), v) + } + val testCasesUTF8String = Seq( + UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]")) + + testCasesUTF8String.foreach(t => { + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create(null, StructType(Seq( + StructField("f1", MapType(StringType(t.collationId), IntegerType), true) + )))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + }) + } + test("Support mode for string expression with collated strings in struct") { case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( @@ -1809,6 +1839,38 @@ class CollationSQLExpressionsSuite }) } + test("Support mode for string expression with collated strings in " + + "recursively nested struct with map with collated keys") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") + ) + testCases.foreach(t => { + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => + s"named_struct('m1', " + + s"map(" + + s"collate(" + + s"'$elt', '${t.collationId}'" + + s"), " + + s"1))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct1" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(i STRUCT>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(cast(mode(i).m1 as string))" + + s" FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + }) + } + test("Support mode for string expression with collated strings in array complex type") { case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( @@ -1907,6 +1969,7 @@ class CollationSQLExpressionsSuite }) } + test("SPARK-48430: Map value extraction with collations") { for { collateKey <- Seq(true, false)