Skip to content

Commit

Permalink
support nested keys in GpuMapConcat (#6290)
Browse files Browse the repository at this point in the history
Signed-off-by: remzi <[email protected]>

Signed-off-by: remzi <[email protected]>
  • Loading branch information
HaoYang670 authored Aug 26, 2022
1 parent 3929002 commit 8f4b4cb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3202,10 +3202,6 @@ object GpuOverrides extends Logging {
}),
expr[MapConcat](
"Returns the union of all the given maps",
// Currently, GpuMapConcat supports nested values but not nested keys.
// We will add the nested key support after
// cuDF can fully support nested types in lists::drop_list_duplicates.
// Issue link: https://github.com/rapidsai/cudf/issues/11093
ExprChecks.projectOnly(TypeSig.MAP.nested(TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 +
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all),
Expand All @@ -3214,13 +3210,6 @@ object GpuOverrides extends Logging {
TypeSig.NULL + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP),
TypeSig.MAP.nested(TypeSig.all)))),
(a, conf, p, r) => new ComplexTypeMergingExprMeta[MapConcat](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
a.dataType.keyType match {
case MapType(_,_,_) | ArrayType(_,_) | StructType(_) => willNotWorkOnGpu(
s"GpuMapConcat does not currently support the key type ${a.dataType.keyType}.")
case _ =>
}
}
override def convertToGpu(child: Seq[Expression]): GpuExpression = GpuMapConcat(child)
}),
expr[ConcatWs](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,18 @@ package com.nvidia.spark.rapids
import org.apache.spark.sql.functions.map_concat

class CollectionOpSuite extends SparkQueryCompareTestSuite {
testGpuFallback(
"MapConcat with Array keys fall back",
"ProjectExec",
ArrayKeyMapDF,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
testSparkResultsAreEqual(
"MapConcat with Array keys",
ArrayKeyMapDF) {
frame => {
import frame.sparkSession.implicits._
frame.select(map_concat($"col1", $"col2"))
}
}

testGpuFallback(
"MapConcat with Struct keys fall back",
"ProjectExec",
StructKeyMapDF,
execsAllowedNonGpu = Seq("ProjectExec", "ShuffleExchangeExec")) {
testSparkResultsAreEqual(
"MapConcat with Struct keys",
StructKeyMapDF) {
frame => {
import frame.sparkSession.implicits._
frame.select(map_concat($"col1", $"col2"))
Expand Down

0 comments on commit 8f4b4cb

Please sign in to comment.