diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
index b8de30b1b06f7..3c68ac1bb84e5 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala
@@ -19,6 +19,7 @@ package org.apache.gluten.execution
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.catalyst.optimizer.NullPropagation
 import org.apache.spark.sql.execution.ProjectExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import java.sql.Timestamp
@@ -1365,4 +1366,30 @@ abstract class ScalarFunctionsValidateSuite extends FunctionsValidateSuite {
       checkGlutenOperatorMatch[ProjectExecTransformer]
     }
   }
+
+  testWithSpecifiedSparkVersion("array insert", Some("3.4")) {
+    withTempPath {
+      path =>
+        Seq[Seq[Integer]](Seq(1, null, 5, 4), Seq(5, -1, 8, 9, -7, 2), Seq.empty, null)
+          .toDF("value")
+          .write
+          .parquet(path.getCanonicalPath)
+
+        spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("array_tbl")
+
+        Seq("true", "false").foreach { legacyNegativeIndex =>
+          withSQLConf(SQLConf.LEGACY_NEGATIVE_INDEX_IN_ARRAY_INSERT.key -> legacyNegativeIndex) {
+            runQueryAndCompare(
+              """
+                |select
+                | array_insert(value, 1, 0), array_insert(value, 10, 0),
+                | array_insert(value, -1, 0), array_insert(value, -10, 0)
+                |from array_tbl
+                |""".stripMargin) {
+              checkGlutenOperatorMatch[ProjectExecTransformer]
+            }
+          }
+        }
+    }
+  }
 }
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index c5ba3a8a78391..b1f9fac849c46 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -620,6 +620,14 @@ object ExpressionConverter extends SQLConfHelper with Logging {
           replaceWithExpressionTransformer0(a.function, attributeSeq, expressionsMap),
           a
         )
+      case arrayInsert if arrayInsert.getClass.getSimpleName.equals("ArrayInsert") =>
+        // Since spark 3.4.0
+        val children = SparkShimLoader.getSparkShims.extractExpressionArrayInsert(arrayInsert)
+        GenericExpressionTransformer(
+          substraitExprName,
+          children.map(replaceWithExpressionTransformerInternal(_, attributeSeq, expressionsMap)),
+          arrayInsert
+        )
       case s: Shuffle =>
         GenericExpressionTransformer(
           substraitExprName,
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 96a615615179c..f198bb7e17c96 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -272,6 +272,7 @@ object ExpressionNames {
   final val SHUFFLE = "shuffle"
   final val ZIP_WITH = "zip_with"
   final val FLATTEN = "flatten"
+  final val ARRAY_INSERT = "array_insert"
 
   // Map functions
   final val CREATE_MAP = "map"
diff --git a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
index fa6ed18e9fa8b..7671f236c9170 100644
--- a/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala
@@ -266,4 +266,8 @@ trait SparkShims {
       DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
     }
   }
+
+  def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    throw new UnsupportedOperationException("ArrayInsert not supported.")
+  }
 }
diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
index b277139e8300d..8a53fe8367ba6 100644
--- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
+++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala
@@ -492,4 +492,9 @@ class Spark34Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    val expr = arrayInsert.asInstanceOf[ArrayInsert]
+    Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+  }
 }
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 6474c74fe8f3b..d2bf3776194ed 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -517,4 +517,9 @@ class Spark35Shims extends SparkShims {
       RebaseSpec(LegacyBehaviorPolicy.CORRECTED)
     )
   }
+
+  override def extractExpressionArrayInsert(arrayInsert: Expression): Seq[Expression] = {
+    val expr = arrayInsert.asInstanceOf[ArrayInsert]
+    Seq(expr.srcArrayExpr, expr.posExpr, expr.itemExpr, Literal(expr.legacyNegativeIndex))
+  }
 }