Skip to content

Commit

Permalink
[SPARK-26901][SQL][R] Adds child's output into references to avoid co…
Browse files Browse the repository at this point in the history
…lumn-pruning for vectorized gapply()

## What changes were proposed in this pull request?

Currently, looks column pruning is done to vectorized `gapply()`. Given R native function could use all referred fields so it shouldn't be pruned. To avoid this, it adds child's output into `references` like `OutputConsumer`.

```
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
df <- createDataFrame(mtcars)
explain(count(groupBy(gapply(df,
                             "gear",
                             function(key, group) {
                               data.frame(gear = key[[1]], disp = mean(group$disp))
                             },
                             structType("gear double, disp double")))), TRUE)
```

**Before:**

```
== Optimized Logical Plan ==
Aggregate [count(1) AS count#41L]
+- Project
   +- FlatMapGroupsInRWithArrow [...]
      +- Project [gear#9]
         +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false

== Physical Plan ==
*(4) HashAggregate(keys=[], functions=[count(1)], output=[count#41L])
+- Exchange SinglePartition
   +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#44L])
      +- *(3) Project
         +- FlatMapGroupsInRWithArrow [...]
            +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(gear#9, 200)
                  +- *(1) Project [gear#9]
                     +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
```

**After:**

```
== Optimized Logical Plan ==
Aggregate [count(1) AS count#91L]
+- Project
   +- FlatMapGroupsInRWithArrow [...]
      +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false

== Physical Plan ==
*(4) HashAggregate(keys=[], functions=[count(1)], output=[count#91L])
+- Exchange SinglePartition
   +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#94L])
      +- *(3) Project
         +- FlatMapGroupsInRWithArrow [...]
            +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(gear#9, 200)
                  +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
```

Currently, it adds corrupt values for missing columns (via pruned columnar batches to Arrow writers that requires non-pruned columns) such as:

```r
...
  c(7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 0, 0, 4.17777978645388e-314)
  c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1.04669129845114e+219)
  c(3.4482690635875e-313, 3.4482690635875e-313, 3.4482690635875e-313,
  c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2.47032822920623e-323)
...
```

which should be something like:

```r
...
  c(4, 4, 1, 2, 2, 4, 4, 1, 2, 1, 1, 2)
  c(26, 30.4, 15.8, 19.7, 15)
  c(4, 4, 8, 6, 8)
  c(120.3, 95.1, 351, 145, 301)
...
```

## How was this patch tested?

Manually tested, and unit tests were added.

The test code is basiaclly:

```r
df <- createDataFrame(mtcars)
count(gapply(df,
             c("gear"),
             function(key, group) {
                stopifnot(all(group$hp > 50))
                group
             },
             schema(df)))
```

`mtcars`'s hp is all more then 50.

```r
> mtcars$hp > 50
 [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
[16] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
[31] TRUE TRUE
```

However, due to corrpt value, (like 0 or 7.xxxxx), werid values were found. So, it's currently being failed as below in the master:

```
Error in handleErrors(returnStatus, conn) :
  org.apache.spark.SparkException: Job aborted due to stage failure: Task 82 in stage 1.0 failed 1 times, most recent failure: Lost task 82.0 in stage 1.0 (TID 198, localhost, executor driver): org.apache.spark.SparkException: R worker exited unexpectedly (crashed)
 Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
```

I also compared the total length while I am here. Regular `gapply` without Arrow has some holes .. so I had to compare the results with R data frame.

Closes apache#23810 from HyukjinKwon/SPARK-26901.

Lead-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
2 people authored and rshkv committed Jun 29, 2020
1 parent 9210793 commit 768ae42
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -3507,11 +3507,15 @@ test_that("gapply() Arrow optimization", {
stopifnot(is.numeric(key[[1]]))
}
stopifnot(class(grouped) == "data.frame")
stopifnot(length(colnames(grouped)) == 11)
# mtcars' hp is more then 50.
stopifnot(all(grouped$hp > 50))
grouped
},
schema(df))
actual <- collect(ret)
expect_equal(actual, expected)
expect_equal(count(ret), nrow(mtcars))
},
finally = {
# Resetting the conf back to default value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ case class FlatMapGroupsInRWithArrow(
keyDeserializer: Expression,
groupingAttributes: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
// This operator always need all columns of its child, even it doesn't reference to.
override def references: AttributeSet = child.outputSet

override protected def stringArgs: Iterator[Any] = Iterator(
inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child)
Expand Down

0 comments on commit 768ae42

Please sign in to comment.