Skip to content

Commit

Permalink
hello
Browse files Browse the repository at this point in the history
  • Loading branch information
GideonPotok committed Jul 30, 2024
1 parent ea7de85 commit 44b78f9
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
Expand All @@ -42,6 +42,23 @@ case class Mode(
this(child, 0, 0, Some(reverse))
}

override def checkInputDataTypes(): TypeCheckResult = {
if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
!child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
!UnsafeRowUtils.isBinaryStable(f))) {
/*
* The Mode class uses collation awareness logic to handle string data.
* Complex types with collated fields are not yet supported.
*/
// TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
TypeCheckResult.TypeCheckFailure("The input to the function 'mode' includes" +
" a map with keys and/or values which are not binary-stable. This is not yet" +
s"supported by ${prettyName}.")
}
}

// Returns null for empty inputs
override def nullable: Boolean = true

Expand Down Expand Up @@ -71,7 +88,6 @@ case class Mode(
buffer
}


private def getCollationAwareBuffer(
childDataType: DataType,
buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
Expand All @@ -84,10 +100,13 @@ case class Mode(

private def getMyBuffer(
childDataType: DataType): Option[AnyRef => _] = {
println(s"get my buffer. ${childDataType.getClass.getCanonicalName}")
childDataType match {
// Short-circuit if there is no collation.
case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
case c: StringType => Some(k =>
case c: StringType =>
println("Get Collation Key")
Some(k =>
CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], c.collationId))
case at: ArrayType => Some(k =>
recursivelyGetBufferForArrayType(at, k.asInstanceOf[ArrayData]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,12 @@ package org.apache.spark.sql

import java.sql.{Date, Timestamp}
import java.text.SimpleDateFormat

import scala.collection.immutable.Seq

import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.internal.{SQLConf, SqlApiConf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -1758,6 +1756,38 @@ class CollationSQLExpressionsSuite
})
}

test("Support Mode.eval(buffer) with complex map types") {
case class UTF8StringModeTestCase[R](
collationId: String,
bufferValues: Map[InternalRow, Long],
result: R)

val bufferValuesUTF8String: Map[Any, Long] = Map(
UTF8String.fromString("a") -> 5L,
UTF8String.fromString("b") -> 4L,
UTF8String.fromString("B") -> 3L,
UTF8String.fromString("d") -> 2L,
UTF8String.fromString("e") -> 1L)

val bufferValuesComplex = bufferValuesUTF8String.map{
case (k, v) => (InternalRow.apply(org.apache.spark.sql.catalyst.util.MapData.apply(Map(k -> 1))), v)
}
val testCasesUTF8String = Seq(
UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"),
UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"),
UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"),
UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]"))

testCasesUTF8String.foreach(t => {
val buffer = new OpenHashMap[AnyRef, Long](5)
val myMode = Mode(child = Literal.create(null, StructType(Seq(
StructField("f1", MapType(StringType(t.collationId), IntegerType), true)
))))
t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase())
})
}

test("Support mode for string expression with collated strings in struct") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
Expand Down Expand Up @@ -1809,6 +1839,38 @@ class CollationSQLExpressionsSuite
})
}

test("Support mode for string expression with collated strings in " +
"recursively nested struct with map with collated keys") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"),
ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"),
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}")
)
testCases.foreach(t => {
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ =>
s"named_struct('m1', " +
s"map(" +
s"collate(" +
s"'$elt', '${t.collationId}'" +
s"), " +
s"1))").mkString(",")
}.mkString(",")

val tableName = s"t_${t.collationId}_mode_nested_struct1"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(i STRUCT<m1: MAP<STRING COLLATE " +
t.collationId + ", INT>>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(cast(mode(i).m1 as string))" +
s" FROM ${tableName}"
checkAnswer(sql(query), Row(t.result))
}
})
}

test("Support mode for string expression with collated strings in array complex type") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
Expand Down Expand Up @@ -1907,6 +1969,7 @@ class CollationSQLExpressionsSuite
})
}


test("SPARK-48430: Map value extraction with collations") {
for {
collateKey <- Seq(true, false)
Expand Down

0 comments on commit 44b78f9

Please sign in to comment.