Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
GideonPotok committed Jul 9, 2024
1 parent 4e57f06 commit f469c2a
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResul
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

Expand All @@ -50,18 +50,7 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
/*
* The Mode class uses collation awareness logic to handle string data.
* Complex types with collated fields are not yet supported.
*/
// TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
" a type of binary-unstable type that is " +
s"not currently supported by ${prettyName}.")
}
super.checkInputDataTypes()
}

override def prettyName: String = "mode"
Expand All @@ -86,6 +75,66 @@ case class Mode(
buffer
}

private def recursivelyGetBufferForArrayType(
a: ArrayType,
data: ArrayData): Seq[Any] = {
(0 until data.numElements()).map { i =>
data.get(i, a.elementType) match {
case k: UTF8String if a.elementType.isInstanceOf[StringType] &&
!a.elementType.asInstanceOf[StringType].supportsBinaryEquality
=> CollationFactory.getCollationKey(k, a.elementType.asInstanceOf[StringType].collationId)
case k if a.elementType.isInstanceOf[StructType] =>
recursivelyGetBufferForStructType(
k.asInstanceOf[InternalRow].toSeq(a.elementType.asInstanceOf[StructType]).zip(
a.elementType.asInstanceOf[StructType].fields))
case k if a.elementType.isInstanceOf[ArrayType] =>
recursivelyGetBufferForArrayType(
a.elementType.asInstanceOf[ArrayType],
k.asInstanceOf[ArrayData])
case k => k
}
}
}

private def getBufferForComplexType(
buffer: OpenHashMap[AnyRef, Long],
d: DataType): Iterable[(AnyRef, Long)] = {
buffer.groupMapReduce {
case (key: InternalRow, _) if d.isInstanceOf[StructType] =>
recursivelyGetBufferForStructType(key.toSeq(d.asInstanceOf[StructType])
.zip(d.asInstanceOf[StructType].fields))
case (key: ArrayData, _) if d.isInstanceOf[ArrayType] =>
recursivelyGetBufferForArrayType(d.asInstanceOf[ArrayType], key)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
}

private def recursivelyGetBufferForStructType(
tuples: Seq[(Any, StructField)]): Seq[Any] = {
tuples.map {
case (k: String, field) if tuples.exists(f => f._2.dataType.isInstanceOf[StringType] &&
!f._2.dataType.asInstanceOf[StringType].supportsBinaryEquality &&
f._2.name == field.name) =>
CollationFactory.getCollationKey(UTF8String.fromString(k),
field.dataType.asInstanceOf[StringType].collationId)
case (k: UTF8String, field) if tuples.exists(f =>
f._2.dataType.isInstanceOf[StringType] &&
!f._2.dataType.asInstanceOf[StringType].supportsBinaryEquality &&
f._2.name == field.name) =>
CollationFactory.getCollationKey(k, field.dataType.asInstanceOf[StringType].collationId)
case (k, field: StructField) if tuples.exists(f => f._2.dataType.isInstanceOf[StructType] &&
f._2.name == field.name) =>
recursivelyGetBufferForStructType (
k.asInstanceOf[InternalRow].toSeq(field.dataType.asInstanceOf[StructType]).zip(
field.dataType.asInstanceOf[StructType].fields))
case (k, structField: StructField) if structField.dataType.isInstanceOf[ArrayType] &&
structField.dataType.asInstanceOf[ArrayType].existsRecursively(s =>
s.isInstanceOf[StringType] && !s.asInstanceOf[StringType].supportsBinaryEquality)
=> recursivelyGetBufferForArrayType(
structField.dataType.asInstanceOf[ArrayType],
k.asInstanceOf[ArrayData])
case (k, _) => k
}
}
override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
Expand All @@ -106,11 +155,13 @@ case class Mode(
val collationAwareBuffer = child.dataType match {
case c: StringType if
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
buffer.toSeq.groupMapReduce {
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], c.collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _: ArrayType | _ : StructType if child.dataType.existsRecursively(s =>
s.isInstanceOf[StringType] && !CollationFactory.fetchCollation(
s.asInstanceOf[StringType].collationId).supportsBinaryEquality) =>
getBufferForComplexType(buffer, child.dataType)
case _ => buffer
}
reverseOpt.map { reverse =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.text.SimpleDateFormat
import scala.collection.immutable.Seq

import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
Expand Down Expand Up @@ -1710,6 +1710,40 @@ class CollationSQLExpressionsSuite
})
}

test("Support Mode.eval(buffer) with complex types") {
case class UTF8StringModeTestCase[R](
collationId: String,
bufferValues: Map[InternalRow, Long],
result: R)

val bufferValuesUTF8String: Map[Any, Long] = Map(
UTF8String.fromString("a") -> 5L,
UTF8String.fromString("b") -> 4L,
UTF8String.fromString("B") -> 3L,
UTF8String.fromString("d") -> 2L,
UTF8String.fromString("e") -> 1L)

val bufferValuesComplex = bufferValuesUTF8String.map{
case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v)
}
val testCasesUTF8String = Seq(
UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"),
UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"),
UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"),
UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]"))

testCasesUTF8String.foreach(t => {
val buffer = new OpenHashMap[AnyRef, Long](5)
val myMode = Mode(child = Literal.create(null, StructType(Seq(
StructField("f1", StringType(t.collationId), true),
StructField("f2", StringType(t.collationId), true),
StructField("f3", StringType(t.collationId), true)
))))
t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase())
})
}

test("Support mode for string expression with collated strings in struct") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
Expand All @@ -1718,6 +1752,7 @@ class CollationSQLExpressionsSuite
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
)

testCases.foreach(t => {
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ => s"named_struct('f1'," +
Expand All @@ -1730,33 +1765,7 @@ class CollationSQLExpressionsSuite
t.collationId + ", f2: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
if(t.collationId == "UTF8_LCASE" ||
t.collationId == "unicode_ci" ||
t.collationId == "unicode") {
// Cannot resolve "mode(i)" due to data type mismatch:
// Input to function mode was a complex type with strings collated on non-binary
// collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "The input to the function 'mode'" +
" was a type of binary-unstable type that is not currently supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
sql(query)
},
errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = params,
queryContext = Array(
ExpectedContext(objectType = "",
objectName = "",
startIndex = 13,
stopIndex = 19,
fragment = "mode(i)")
)
)
} else {
checkAnswer(sql(query), Row(t.result))
}
checkAnswer(sql(query), Row(t.result))
}
})
}
Expand All @@ -1775,44 +1784,90 @@ class CollationSQLExpressionsSuite
s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",")
}.mkString(",")

val tableName = s"t_${t.collationId}_mode_nested_struct"
val tableName = s"t_${t.collationId}_mode_nested_struct1"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE " +
t.collationId + ">, f3: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
if(t.collationId == "UTF8_LCASE" ||
t.collationId == "unicode_ci" ||
t.collationId == "unicode") {
// Cannot resolve "mode(i)" due to data type mismatch:
// Input to function mode was a complex type with strings collated on non-binary
// collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "The input to the function 'mode' " +
"was a type of binary-unstable type that is not currently supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
sql(query)
},
errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = params,
queryContext = Array(
ExpectedContext(objectType = "",
objectName = "",
startIndex = 13,
stopIndex = 19,
fragment = "mode(i)")
)
)
} else {
checkAnswer(sql(query), Row(t.result))
}
checkAnswer(sql(query), Row(t.result))
}
})
}

test("Support mode for string expression with collated strings in array complex type") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
)
testCases.foreach(t => {
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ => s"array(named_struct('f2', " +
s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",")
}.mkString(",")

val tableName = s"t_${t.collationId}_mode_nested_struct2"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(" +
s"i ARRAY< STRUCT<f2: STRING COLLATE ${t.collationId}, f3: INT>>)" +
s" USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM ${tableName}"
checkAnswer(sql(query), Row(t.result))
}
})
}

test("Support mode for string expression with collated strings in 3D array type") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
)
testCases.foreach(t => {
val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
(0L to numRepeats).map(_ =>
s"array(" +
s"array(" +
s"array(" +
s"collate('$elt', '${t.collationId}')" +
s")" +
s")" +
s")").mkString(",")
}.mkString(",")

val tableName = s"t_${t.collationId}_mode_nested_3d_array"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(" +
s"i ARRAY<" +
s"ARRAY<" +
s"ARRAY<" +
s"STRING COLLATE ${t.collationId}" +
s">" +
s">" +
s">)" +
s" USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(" +
s"element_at(" +
s"element_at(" +
s"element_at(" +
s"mode(i)," +
s" 1)," +
s" 1)," +
s" 1)" +
s") FROM ${tableName}"
checkAnswer(sql(query), Row(t.result))
}
})
}

test("Support mode for string expression with collated complex type - Highly nested") {
case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
val testCases = Seq(
ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
Expand All @@ -1826,36 +1881,14 @@ class CollationSQLExpressionsSuite
s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",")
}.mkString(",")

val tableName = s"t_${t.collationId}_mode_nested_struct"
val tableName = s"t_${t.collationId}_mode_highly_nested_struct"
withTable(tableName) {
sql(s"CREATE TABLE ${tableName}(" +
s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE ${t.collationId}>>, f3: INT>>)" +
s" USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
if(t.collationId == "UTF8_LCASE" ||
t.collationId == "unicode_ci" || t.collationId == "unicode") {
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "The input to the function 'mode' was a type" +
" of binary-unstable type that is not currently supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
sql(query)
},
errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = params,
queryContext = Array(
ExpectedContext(objectType = "",
objectName = "",
startIndex = 35,
stopIndex = 41,
fragment = "mode(i)")
)
)
} else {
checkAnswer(sql(query), Row(t.result))
}
checkAnswer(sql(query), Row(t.result))
}
})
}
Expand Down

0 comments on commit f469c2a

Please sign in to comment.