Collation support for the Levenshtein expression.

* stringExpressions.scala: Levenshtein.inputTypes declares StringTypeAnyCollation
  instead of StringType for both string arguments; the optional threshold argument
  stays IntegerType.
* CollationStringExpressionsSuite.scala: a new test runs levenshtein('kitten', 'sitTing')
  with each of UTF8_BINARY, UTF8_BINARY_LCASE, UNICODE and UNICODE_CI as the default
  collation and checks both the value and the IntegerType result. The expected distance
  is 4 even under UTF8_BINARY_LCASE, so the 't'/'T' mismatch is still counted as an edit;
  with threshold 3 the expected result is -1.

NOTE(review): this patch was recovered from text whose line breaks and indentation had
been collapsed. Leading whitespace follows the 2-space Scala layout that is consistent
with the hunk counts -- confirm against upstream before applying. The collapsed text
showed two consecutive '+' markers before the trailing context; the +645,28 count leaves
room for only one added blank line, so one is emitted. The test's `th` value now uses
Option.fold instead of isDefined/get, so the post-image hash on the second index line no
longer matches the content.

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 09ec501311ade..ac23962f41ed3 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -2240,8 +2240,8 @@ case class Levenshtein(
   }
 
   override def inputTypes: Seq[AbstractDataType] = threshold match {
-    case Some(_) => Seq(StringType, StringType, IntegerType)
-    case _ => Seq(StringType, StringType)
+    case Some(_) => Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
+    case _ => Seq(StringTypeAnyCollation, StringTypeAnyCollation)
   }
 
   override def children: Seq[Expression] = threshold match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index db02946e3dfe5..31be149b9c9cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -645,6 +645,28 @@ class CollationStringExpressionsSuite
     })
   }
 
+  test("Levenshtein string expression with collation") {
+    // Supported collations
+    case class LevenshteinTestCase(
+      left: String, right: String, collationName: String, threshold: Option[Int], result: Int
+    )
+    val testCases = Seq(
+      LevenshteinTestCase("kitten", "sitTing", "UTF8_BINARY", None, result = 4),
+      LevenshteinTestCase("kitten", "sitTing", "UTF8_BINARY_LCASE", None, result = 4),
+      LevenshteinTestCase("kitten", "sitTing", "UNICODE", Some(3), result = -1),
+      LevenshteinTestCase("kitten", "sitTing", "UNICODE_CI", Some(3), result = -1)
+    )
+    testCases.foreach(t => {
+      withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collationName) {
+        val th = t.threshold.fold("")(limit => s", $limit")
+        val query = s"select levenshtein('${t.left}', '${t.right}'$th)"
+        // Result & data type
+        checkAnswer(sql(query), Row(t.result))
+        assert(sql(query).schema.fields.head.dataType.sameType(IntegerType))
+      }
+    })
+  }
+
   test("Support Left/Right/Substr with collation") {
     case class SubstringTestCase(
       method: String,