[SPARK-26901][SQL][R] Adds child's output into references to avoid column-pruning for vectorized gapply()

## What changes were proposed in this pull request?

Currently, it looks like column pruning is applied to vectorized `gapply()`. Since the R native function can use all of the referred fields, the input should not be pruned. To avoid this, this patch adds the child's output into `references`, like `OutputConsumer`.

```
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
df <- createDataFrame(mtcars)
explain(count(groupBy(gapply(df,
                             "gear",
                             function(key, group) {
                               data.frame(gear = key[[1]], disp = mean(group$disp))
                             },
                             structType("gear double, disp double")))), TRUE)
```

**Before:**

```
== Optimized Logical Plan ==
Aggregate [count(1) AS count#41L]
+- Project
   +- FlatMapGroupsInRWithArrow [...]
      +- Project [gear#9]
         +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false

== Physical Plan ==
*(4) HashAggregate(keys=[], functions=[count(1)], output=[count#41L])
+- Exchange SinglePartition
   +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#44L])
      +- *(3) Project
         +- FlatMapGroupsInRWithArrow [...]
            +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(gear#9, 200)
                  +- *(1) Project [gear#9]
                     +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
```

**After:**

```
== Optimized Logical Plan ==
Aggregate [count(1) AS count#91L]
+- Project
   +- FlatMapGroupsInRWithArrow [...]
      +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false

== Physical Plan ==
*(4) HashAggregate(keys=[], functions=[count(1)], output=[count#91L])
+- Exchange SinglePartition
   +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#94L])
      +- *(3) Project
         +- FlatMapGroupsInRWithArrow [...]
            +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(gear#9, 200)
                  +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
```

Currently, it produces corrupt values for missing columns (because pruned columnar batches are fed to Arrow writers that require the non-pruned columns), such as:

```r
...
c(7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323,
  7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323,
  7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323,
  7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323,
  0, 0, 4.17777978645388e-314)
c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1.04669129845114e+219)
c(3.4482690635875e-313, 3.4482690635875e-313, 3.4482690635875e-313,
c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2.47032822920623e-323)
...
```

which should be something like:

```r
...
c(4, 4, 1, 2, 2, 4, 4, 1, 2, 1, 1, 2)
c(26, 30.4, 15.8, 19.7, 15)
c(4, 4, 8, 6, 8)
c(120.3, 95.1, 351, 145, 301)
...
```

## How was this patch tested?

Manually tested, and unit tests were added. The test code is basically:

```r
df <- createDataFrame(mtcars)
count(gapply(df,
             c("gear"),
             function(key, group) {
               stopifnot(all(group$hp > 50))
               group
             },
             schema(df)))
```

`mtcars`'s `hp` values are all greater than 50:

```r
> mtcars$hp > 50
 [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
[16] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
[31] TRUE TRUE
```

However, due to the corrupt values (like 0 or 7.xxxxx), weird values were found.
So it currently fails in master as below:

```
Error in handleErrors(returnStatus, conn) :
  org.apache.spark.SparkException: Job aborted due to stage failure: Task 82 in stage 1.0 failed 1 times, most recent failure: Lost task 82.0 in stage 1.0 (TID 198, localhost, executor driver): org.apache.spark.SparkException: R worker exited unexpectedly (crashed)
Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
```

I also compared the total lengths while I was here. Regular `gapply` without Arrow has some holes, so I had to compare the results against an R data frame.

Closes apache#23810 from HyukjinKwon/SPARK-26901.

Lead-authored-by: Hyukjin Kwon <[email protected]>
Co-authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
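The mechanism behind the fix can be illustrated with a small model. This is a hypothetical sketch, not Spark's actual Catalyst classes (`Attr`, `PlanNode`, `GroupedROp`, and `prune` are all made up for illustration): a toy pruning rule keeps only the child columns a node reports in `references`, so a node whose user function may read any field must report its entire child output, which is what this patch makes `FlatMapGroupsInRWithArrow` do.

```scala
// Hypothetical model of reference-based column pruning.
// None of these names are Spark's real classes; they only model the idea.
case class Attr(name: String)

trait PlanNode {
  def childOutput: Seq[Attr] // columns the child node produces
  def references: Set[Attr]  // columns this node claims to need
}

// Stand-in for FlatMapGroupsInRWithArrow. With referenceAll = false it
// only claims the grouping column (the pre-fix behavior); with
// referenceAll = true it claims the whole child output (the fix).
case class GroupedROp(childOutput: Seq[Attr],
                      grouping: Attr,
                      referenceAll: Boolean) extends PlanNode {
  def references: Set[Attr] =
    if (referenceAll) childOutput.toSet else Set(grouping)
}

// Toy column-pruning rule: keep only the referenced child columns.
def prune(node: PlanNode): Seq[Attr] =
  node.childOutput.filter(node.references.contains)

val cols = Seq("mpg", "cyl", "disp", "gear").map(Attr.apply)

val before = GroupedROp(cols, Attr("gear"), referenceAll = false)
val after  = GroupedROp(cols, Attr("gear"), referenceAll = true)

println(prune(before).map(_.name)) // only the grouping column survives
println(prune(after).map(_.name))  // every column reaches the R worker
```

In this model, the actual change corresponds to flipping `referenceAll` to `true`: once `references` covers the child's output, a pruning rule has nothing left to remove, which matches the "After" plan above where the `Project [gear#9]` under `FlatMapGroupsInRWithArrow` disappears.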