This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-207] fix issues found in scala unit tests (#356)
* convert sql tests

* fix compilation error due to missing validity

* fix literal aggregation

* return null for sum and avg when all inputs are null

* fall back for unsupported columnar Like cases

* fix min_max segfault on DateType

* fix grouping_literal and count_literal

* fix bool in HashRelation

* fix max

* fall back AggregateExpression (Count) with filter

* refine

* convert travis tests to native sql tests

* fall back full outer join

* support null literal

* ignore case differences in sort attribute names

* fix wrong validity in isnull

* fix compilation error due to undefined variable in local function

* fix travis: isnotnull(timestamp[us]) not supported in Gandiva
rui-mo authored Jun 23, 2021
1 parent e9c86f7 commit 8fe1253
Showing 25 changed files with 1,067 additions and 218 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
@@ -101,7 +101,7 @@ jobs:
mvn clean install -DskipTests -Dbuild_arrow=OFF
cd ..
mvn clean package -P full-scala-compiler -am -pl native-sql-engine/core -DskipTests -Dbuild_arrow=OFF
mvn test -P full-scala-compiler -DmembersOnlySuites=org.apache.spark.sql.travis -am -DfailIfNoTests=false -Dexec.skip=true -DargLine="-Dspark.test.home=/tmp/spark-3.0.0-bin-hadoop2.7" &> log-file.log
mvn test -P full-scala-compiler -DmembersOnlySuites=org.apache.spark.sql.nativesql -am -DfailIfNoTests=false -Dexec.skip=true -DargLine="-Dspark.test.home=/tmp/spark-3.0.0-bin-hadoop2.7" &> log-file.log
echo '#!/bin/bash' > grep.sh
echo "module_tested=0; module_should_test=8; tests_total=0; while read -r line; do num=\$(echo \"\$line\" | grep -o -E '[0-9]+'); tests_total=\$((tests_total+num)); done <<<\"\$(grep \"Total number of tests run:\" log-file.log)\"; succeed_total=0; while read -r line; do [[ \$line =~ [^0-9]*([0-9]+)\, ]]; num=\${BASH_REMATCH[1]}; succeed_total=\$((succeed_total+num)); let module_tested++; done <<<\"\$(grep \"succeeded\" log-file.log)\"; if test \$tests_total -eq \$succeed_total -a \$module_tested -eq \$module_should_test; then echo \"All unit tests succeed\"; else echo \"Unit tests failed\"; exit 1; fi" >> grep.sh
bash grep.sh
@@ -75,7 +75,7 @@ case class ColumnarGuardRule() extends Rule[SparkPlan] {
case plan: InMemoryTableScanExec =>
new ColumnarInMemoryTableScanExec(plan.attributes, plan.predicates, plan.relation)
case plan: ProjectExec =>
if(!enableColumnarProjFilter) return false
if (!enableColumnarProjFilter) return false
new ColumnarConditionProjectExec(null, plan.projectList, plan.child)
case plan: FilterExec =>
if (!enableColumnarProjFilter) return false
@@ -89,6 +89,13 @@ case class ColumnarBroadcastHashJoinExec(
buildCheck()

def buildCheck(): Unit = {
joinType match {
case _: InnerLike =>
case LeftSemi | LeftOuter | RightOuter | LeftAnti =>
case j: ExistenceJoin =>
case _ =>
throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.")
}
// build check for condition
val conditionExpr: Expression = condition.orNull
if (conditionExpr != null) {
@@ -109,8 +116,6 @@
for (attr <- buildPlan.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
//if (attr.dataType.isInstanceOf[DecimalType])
// throw new UnsupportedOperationException(s"Unsupported data type: ${attr.dataType}")
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
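To make the new check concrete, here is a minimal, self-contained sketch of the pattern this commit adds to all three join operators. The object and method names below are illustrative only; in the plugin it is the columnar guard rule (not this sketch) that catches the exception and keeps the row-based plan.

import org.apache.spark.sql.catalyst.plans._

object JoinTypeCheckSketch {
  // Mirrors the added buildCheck() branch: accept the join types the columnar
  // operators handle, throw for anything else (e.g. full outer join).
  def buildCheck(joinType: JoinType): Unit = joinType match {
    case _: InnerLike =>
    case LeftSemi | LeftOuter | RightOuter | LeftAnti =>
    case _: ExistenceJoin =>
    case _ =>
      throw new UnsupportedOperationException(s"Join Type $joinType is not supported yet.")
  }

  // Hypothetical caller: translate the exception into a fall-back decision.
  def supportsColumnar(joinType: JoinType): Boolean =
    try { buildCheck(joinType); true }
    catch { case _: UnsupportedOperationException => false }

  def main(args: Array[String]): Unit = {
    println(supportsColumnar(Inner))     // true
    println(supportsColumnar(FullOuter)) // false -> row-based join ("fall back full outer join")
  }
}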

Large diffs are not rendered by default.

@@ -116,6 +116,13 @@ case class ColumnarShuffledHashJoinExec(
}

def buildCheck(): Unit = {
joinType match {
case _: InnerLike =>
case LeftSemi | LeftOuter | RightOuter | LeftAnti =>
case j: ExistenceJoin =>
case _ =>
throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.")
}
// build check for condition
val conditionExpr: Expression = condition.orNull
if (conditionExpr != null) {
@@ -338,6 +338,13 @@ case class ColumnarSortMergeJoinExec(
}*/

def buildCheck(): Unit = {
joinType match {
case _: InnerLike =>
case LeftSemi | LeftOuter | RightOuter | LeftAnti =>
case j: ExistenceJoin =>
case _ =>
throw new UnsupportedOperationException(s"Join Type ${joinType} is not supported yet.")
}
// build check for condition
val conditionExpr: Expression = condition.orNull
if (conditionExpr != null) {
@@ -105,6 +105,20 @@ class ColumnarLike(left: Expression, right: Expression, original: Expression)
extends Like(left: Expression, right: Expression)
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
if (original.asInstanceOf[Like].escapeChar.toString.nonEmpty) {
throw new UnsupportedOperationException(
s"escapeChar is not supported in ColumnarLike")
}
if (!right.isInstanceOf[Literal]) {
throw new UnsupportedOperationException(
s"Gandiva 'like' function requires a literal as the second parameter.")
}
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val (left_node, left_type): (TreeNode, ArrowType) =
left.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
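For reference, a small hedged sketch (Spark 3.0-style expression API with made-up attributes) of the two shapes of LIKE that the new buildCheck sends down the fall-back path: a custom ESCAPE character, and a non-literal pattern, since Gandiva's like function requires a literal second argument.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Like, Literal}
import org.apache.spark.sql.types.StringType

object LikeFallbackSketch {
  def main(args: Array[String]): Unit = {
    val s = AttributeReference("s", StringType)()
    val p = AttributeReference("p", StringType)()

    val supported    = Like(s, Literal("m%"), '\\')   // literal pattern, default escape
    val customEscape = Like(s, Literal("m@ca"), '@')  // e.g. "s LIKE 'm@ca' ESCAPE '@'"
    val nonLiteral   = Like(s, p, '\\')               // pattern is a column, not a literal

    println(supported.right.isInstanceOf[Literal])    // true  -> stays columnar
    println(customEscape.escapeChar)                  // '@'   -> falls back
    println(nonLiteral.right.isInstanceOf[Literal])   // false -> falls back
  }
}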
@@ -66,19 +66,19 @@ class ColumnarSorter(
var shuffle_elapse: Long = 0
var total_elapse: Long = 0
val inputBatchHolder = new ListBuffer[ColumnarBatch]()
var nextVector: FieldVector = null
var nextVector: FieldVector = _
var closed: Boolean = false
val resultSchema = StructType(
val resultSchema: StructType = StructType(
outputAttributes
.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr)
StructField(s"${attr.name}", attr.dataType, true)
StructField(s"${attr.name.toLowerCase()}", attr.dataType, nullable = true)
})
.toArray)
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr)
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})
val arrowSchema = new Schema(outputFieldList.asJava)
var sort_iterator: BatchIterator = _
@@ -182,25 +182,33 @@

object ColumnarSorter extends Logging {

def prepareRelationFunction(
sortOrder: Seq[SortOrder],
outputAttributes: Seq[Attribute]): TreeNode = {
def checkIfKeyFound(sortOrder: Seq[SortOrder], outputAttributes: Seq[Attribute]): Unit = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr)
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})

val keyFieldList: List[Field] = sortOrder.toList.map(sort => {
sortOrder.toList.foreach(sort => {
val attr = ConverterUtils.getAttrFromExpr(sort.child)
val field = Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
val field = Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
if (outputFieldList.indexOf(field) == -1) {
throw new UnsupportedOperationException(
s"ColumnarSorter not found ${attr.name}#${attr.exprId.id} in ${outputAttributes}")
s"ColumnarSorter not found ${attr.name.toLowerCase()}#${attr.exprId.id} " +
s"in ${outputAttributes}")
}
field
});
})
}

def prepareRelationFunction(
sortOrder: Seq[SortOrder],
outputAttributes: Seq[Attribute]): TreeNode = {
checkIfKeyFound(sortOrder, outputAttributes)
val keyFieldList: List[Field] = sortOrder.toList.map(sort => {
val attr = ConverterUtils.getAttrFromExpr(sort.child)
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})

val key_args_node = TreeBuilder.makeFunction(
"key_field",
@@ -229,25 +237,15 @@
sparkConf: SparkConf,
result_type: Int = 0): TreeNode = {
logInfo(s"ColumnarSorter sortOrder is ${sortOrder}, outputAttributes is ${outputAttributes}")
checkIfKeyFound(sortOrder, outputAttributes)
val NaNCheck = ColumnarPluginConfig.getConf.enableColumnarNaNCheck
val codegen = ColumnarPluginConfig.getConf.enableColumnarCodegenSort
/////////////// Prepare ColumnarSorter //////////////
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr)
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
})

val keyFieldList: List[Field] = sortOrder.toList.map(sort => {
val attr = ConverterUtils.getAttrFromExpr(sort.child)
val field = Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
if (outputFieldList.indexOf(field) == -1) {
throw new UnsupportedOperationException(
s"ColumnarSorter not found ${attr.name}#${attr.exprId.id} in ${outputAttributes}")
}
field
});
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})

/*
Get the sort directions and nulls order from SortOrder.
@@ -353,8 +351,8 @@
_sparkConf: SparkConf): (ExpressionTree, Schema) = {
val outputFieldList: List[Field] = outputAttributes.toList.map(expr => {
val attr = ConverterUtils.getAttrFromExpr(expr)
Field
.nullable(s"${attr.name}#${attr.exprId.id}", CodeGeneration.getResultType(attr.dataType))
Field.nullable(s"${attr.name.toLowerCase()}#${attr.exprId.id}",
CodeGeneration.getResultType(attr.dataType))
})
val retType = Field.nullable("res", new ArrowType.Int(32, true))
val sort_node =
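The switch to attr.name.toLowerCase() matters because the sort key is matched against the output fields by Arrow Field equality, which includes the field name. The standalone sketch below (plain Arrow Java API, made-up field names) shows why a case mismatch used to produce the "ColumnarSorter not found" error that checkIfKeyFound now raises in one place.

import org.apache.arrow.vector.types.pojo.{ArrowType, Field}

object CaseInsensitiveSortKeySketch {
  def main(args: Array[String]): Unit = {
    val int32 = new ArrowType.Int(32, true)
    // Output attributes as Arrow fields, already lowercased: "a#1", "b#2".
    val outputFields = List(Field.nullable("a#1", int32), Field.nullable("b#2", int32))

    // A sort key whose attribute was captured with different casing.
    val mixedCaseKey = Field.nullable("A#1", int32)
    val loweredKey   = Field.nullable("A#1".toLowerCase, int32)

    println(outputFields.indexOf(mixedCaseKey)) // -1 -> would raise "ColumnarSorter not found ..."
    println(outputFields.indexOf(loweredKey))   // 0  -> key resolves once both sides are lowercased
  }
}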
@@ -56,7 +56,6 @@ class ColumnarIsNotNull(child: Expression, original: Expression)
FloatType,
DoubleType,
DateType,
TimestampType,
BooleanType,
StringType,
BinaryType)
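As context for dropping TimestampType, here is a minimal sketch of the supported-type check; the list repeats only the types visible in this hunk and the method name is illustrative. Rejecting the type up front makes the expression fall back instead of reaching Gandiva, where isnotnull(timestamp[us]) is not implemented.

import org.apache.spark.sql.types._

object IsNotNullTypeCheckSketch {
  // Only the types visible in the diff hunk; the real list in ColumnarIsNotNull is longer.
  val supportedTypes: Seq[DataType] =
    Seq(FloatType, DoubleType, DateType, BooleanType, StringType, BinaryType)

  def check(dataType: DataType): Unit =
    if (!supportedTypes.contains(dataType)) {
      throw new UnsupportedOperationException(s"$dataType is not supported in ColumnarIsNotNull")
    }

  def main(args: Array[String]): Unit = {
    check(DateType) // accepted
    try check(TimestampType) catch { // rejected -> whole expression falls back
      case e: UnsupportedOperationException => println(e.getMessage)
    }
  }
}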
@@ -385,6 +385,18 @@ object ConverterUtils extends Logging {
}
}

def printBatch(cb: ColumnarBatch): Unit = {
var batch = ""
for (rowId <- 0 until cb.numRows()) {
var row = ""
for (colId <- 0 until cb.numCols()) {
row += (cb.column(colId).getUTF8String(rowId) + " ")
}
batch += (row + "\n")
}
logWarning(s"batch:\n$batch")
}

def getColumnarFuncNode(
expr: Expression,
attributes: Seq[Attribute] = null): (TreeNode, ArrowType) = {
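A hedged usage note for the new printBatch helper: it is a debugging aid that renders every cell through getUTF8String, so it is most useful on string-typed columns. The call below assumes a ColumnarBatch value named cb is in scope (for example, while stepping through an operator):

// Hypothetical call site; cb is whatever ColumnarBatch you want to inspect.
ConverterUtils.printBatch(cb)
// Logs one warning entry with one line per row, columns separated by spaces, e.g.
//   batch:
//   1 foo
//   2 bar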
@@ -275,7 +275,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}

ignore("SPARK-8828 sum should return null if all input values are null") {
test("SPARK-8828 sum should return null if all input values are null") {
checkAnswer(
sql("select sum(a), avg(a) from allNulls"),
Seq(Row(null, null))
@@ -2832,7 +2832,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil)
}

ignore("Support filter clause for aggregate function with hash aggregate") {
test("Support filter clause for aggregate function with hash aggregate") {
Seq(("COUNT(a)", 3), ("COLLECT_LIST(a)", Seq(1, 2, 3))).foreach { funcToResult =>
val query = s"SELECT ${funcToResult._1} FILTER (WHERE b > 1) FROM testData2"
val df = sql(query)
@@ -3734,7 +3734,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
}
}

ignore("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") {
test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") {
withTempView("df") {
Seq("m@ca").toDF("s").createOrReplaceTempView("df")

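To try the re-enabled behaviors outside the suite, here is a hedged, self-contained Spark snippet; it uses a local session and inline VALUES rather than the suite's actual allNulls/testData2 fixtures. It covers the two cases touched above: sum/avg over an all-null column returning null, and an aggregate with a FILTER clause, which this commit now executes via the row-based fallback.

import org.apache.spark.sql.SparkSession

object ReenabledBehaviorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

    // SPARK-8828: sum/avg over all-null input should return null, not 0.
    spark.sql(
      "SELECT sum(a), avg(a) FROM VALUES (CAST(NULL AS INT)), (CAST(NULL AS INT)) AS t(a)"
    ).show() // expect: null, null

    // Aggregate with a FILTER clause; with this commit such plans fall back to the
    // row-based hash aggregate instead of failing in the columnar operator.
    spark.sql(
      "SELECT count(a) FILTER (WHERE b > 1) FROM VALUES (1, 1), (2, 2), (3, 2) AS t(a, b)"
    ).show() // expect: 2

    spark.stop()
  }
}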