Skip to content

Commit

Permalink
Add array_insert function for Spark 3.4+
Browse files Browse the repository at this point in the history
  • Loading branch information
ivoson committed Sep 5, 2024
1 parent 376167e commit 71aa28b
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.gluten.execution
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.catalyst.optimizer.NullPropagation
import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

import java.sql.Timestamp
Expand Down Expand Up @@ -1365,4 +1366,30 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}

testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
withTempPath {
path =>
Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
.toDF("value")
.write
.parquet(path.getCanonicalPath)

spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")

Seq("true", "false").foreach { legacyNegativeIndex =>
withSQLConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT.key -> legacyNegativeIndex) {
runQueryAndCompare(
"""
|select
| array_insert(value, 1, 0), array_insert(value, 10, 0),
| array_insert(value, -1, 0), array_insert(value, -10, 0)
|from array_tbl
|""".stripMargin) {
checkGlutenOperatorMatch[ProjectExecTransformer]
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap),
a
)
      case arrayInsert if arrayInsert.getClass.getSimpleName.equals("ArrayInsert") =>
        // Since spark 3.4.0.
        // Matched by class name (not a type pattern) because the ArrayInsert
        // class is only available on Spark 3.4+; the version-specific shim
        // deconstructs it into its child expressions.
        // NOTE(review): nearby cases call replaceWithExpressionTransformer0
        // while this one calls replaceWithExpressionTransformerInternal —
        // confirm this is the intended helper.
        val children = SparkShimLoader.getSparkShims.extractExpressionArrayInsert(arrayInsert)
        GenericExpressionTransformer(
          substraitExprName,
          children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
          arrayInsert
        )
case s: Shuffle =>
GenericExpressionTransformer(
substraitExprName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ object ExpressionNames {
final val SHUFFLE = "shuffle"
final val ZIP_WITH = "zip_with"
final val FLATTEN = "flatten"
  final val ARRAY_INSERT = "array_insert" // Since Spark 3.4

// Map functions
final val CREATE_MAP = "map"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,4 +266,8 @@ trait SparkShims {
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}

def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
throw new UnsupportedOperationException("ArrayInsert not supported.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -492,4 +492,9 @@ class Spark34Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -517,4 +517,9 @@ class Spark35Shims extends SparkShims {
RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
)
}

override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
val expr = arrayInsert.asInstanceOf[ArrayInsert]
Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
}
}

0 comments on commit 71aa28b

Please sign in to comment.