From 8126d09fb5b969c1e293f1f8c41bec35357f74b5 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Wed, 13 Feb 2019 11:19:58 +0800
Subject: [PATCH] [SPARK-26761][SQL][R] Vectorized R gapply() implementation

## What changes were proposed in this pull request?

This PR adds a vectorized `gapply()` implementation in R, i.e. Arrow optimization.

This can be tested as below:

```bash
$ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

```r
df <- createDataFrame(mtcars)
collect(gapply(df,
               "gear",
               function(key, group) {
                 data.frame(gear = key[[1]], disp = mean(group$disp) > group$disp)
               },
               structType("gear double, disp boolean")))
```

### Requirements

- R 3.5.x
- Arrow package 0.12+

```bash
Rscript -e 'remotes::install_github("apache/arrow@apache-arrow-0.12.0", subdir = "r")'
```

**Note:** currently, the Arrow R package is not on CRAN. Please take a look at ARROW-3204.
**Note:** currently, the Arrow R package does not seem to support Windows. Please take a look at ARROW-3204.

### Benchmarks

**Shell**

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=false
```

```bash
sync && sudo purge
./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
```

**R code**

```r
rdf <- read.csv("500000.csv")
rdf <- rdf[, c("Month.of.Joining", "Weight.in.Kgs.")]  # We're only interested in the key and the values to calculate.
df <- cache(createDataFrame(rdf))
count(df)

test <- function() {
  options(digits.secs = 6)  # milliseconds
  start.time <- Sys.time()
  count(gapply(df,
               "Month_of_Joining",
               function(key, group) {
                 data.frame(Month_of_Joining = key[[1]],
                            Weight_in_Kgs_ = mean(group$Weight_in_Kgs_) > group$Weight_in_Kgs_)
               },
               structType("Month_of_Joining integer, Weight_in_Kgs_ boolean")))
  end.time <- Sys.time()
  time.taken <- end.time - start.time
  print(time.taken)
}

test()
```

**Data (350 MB):**

```r
object.size(read.csv("500000.csv"))
350379504 bytes
```

"500000 Records" http://eforexcel.com/wp/downloads-16-sample-csv-files-data-sets-for-testing/

**Results**

Without Arrow optimization:

```
Time difference of 35.67459 secs
```

With Arrow optimization:

```
Time difference of 4.301399 secs
```

The performance improvement was around **829%**.

**Note that** the actual improvement is almost certainly more than 829%: I gave up testing the non-Arrow path with bigger data sizes because it took far too long.

### Limitations

- For now, Arrow optimization with R does not support the case where the data is `raw`, or where the user explicitly specifies a float type in the schema; these produce corrupt values.
- Due to ARROW-4512, data cannot be sent and received batch by batch; all batches have to be sent at once in Arrow stream format (a sketch of this round trip follows below). This needs improvement later.

## How was this patch tested?

Unit tests were added.

**TODOs:**
- [x] Draft code
- [x] Make the tests pass
- [x] Make the CRAN check pass
- [x] Performance measurement
- [x] Supportability investigation (e.g., which types are supported)

Closes #23746 from HyukjinKwon/SPARK-26759.
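To make the ARROW-4512 limitation above concrete, here is a small, hedged sketch of the whole-stream round trip that the R worker performs. It is not code from this patch: the `raw` vector stands in for the socket payload exchanged with the JVM, the grouping of `mtcars` is purely illustrative, and it assumes the Arrow 0.12 R package is installed (it uses the same `write_arrow`, `RecordBatchStreamReader`, and `as_tibble` helpers that the patch fetches from the `arrow` namespace).

```r
# Minimal sketch of the whole-stream round trip, outside Spark: a raw vector stands in
# for the socket payload between the JVM and the R worker. Only the arrow 0.12 helpers
# that the patch itself pulls from the arrow namespace are used.
stopifnot(requireNamespace("arrow", quietly = TRUE))
write_arrow <- get("write_arrow", envir = asNamespace("arrow"))
RecordBatchStreamReader <- get("RecordBatchStreamReader", envir = asNamespace("arrow"))
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))

# Worker output side: the per-group results are rbind-ed together and written as one
# Arrow stream, because ARROW-4512 prevents streaming batch by batch over the connection.
outputs <- split(mtcars, mtcars$gear)          # stand-in for the per-group results
payload <- write_arrow(do.call("rbind", outputs), raw())

# Reading side: the whole stream is materialized at once, and every record batch
# (a single one here) is converted back into a plain data.frame, mirroring
# readDeserializeInArrow().
batches <- RecordBatchStreamReader(payload)$batches()
dfs <- lapply(batches, function(batch) as.data.frame(as_tibble(batch)))
str(dfs[[1]])
```

The real worker additionally length-prefixes the binary and receives the serialized group keys separately; see `readDeserializeInArrow()` in `R/pkg/R/deserialize.R` below.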
Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- R/pkg/R/deserialize.R | 27 +++ R/pkg/R/group.R | 23 ++ R/pkg/inst/worker/worker.R | 33 ++- R/pkg/tests/fulltests/test_sparkSQL.R | 110 ++++++++++ .../org/apache/spark/api/r/RRunner.scala | 116 +++++----- .../sql/catalyst/plans/logical/object.scala | 63 ++++-- .../spark/sql/execution/SparkStrategies.scala | 3 + .../apache/spark/sql/execution/objects.scala | 68 ++++++ .../spark/sql/execution/r/ArrowRRunner.scala | 205 ++++++++++++++++++ 9 files changed, 578 insertions(+), 70 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index cb03f1667629f..4c5d2bcb9f035 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -231,6 +231,33 @@ readMultipleObjectsWithKeys <- function(inputCon) { list(keys = keys, data = data) # this is a list of keys and corresponding data } +readDeserializeInArrow <- function(inputCon) { + # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. + requireNamespace1 <- requireNamespace + if (requireNamespace1("arrow", quietly = TRUE)) { + RecordBatchStreamReader <- get( + "RecordBatchStreamReader", envir = asNamespace("arrow"), inherits = FALSE) + as_tibble <- get("as_tibble", envir = asNamespace("arrow")) + + # Currently, there looks no way to read batch by batch by socket connection in R side, + # See ARROW-4512. Therefore, it reads the whole Arrow streaming-formatted binary at once + # for now. + dataLen <- readInt(inputCon) + arrowData <- readBin(inputCon, raw(), as.integer(dataLen), endian = "big") + batches <- RecordBatchStreamReader(arrowData)$batches() + + # Read all groupped batches. Tibble -> data.frame is cheap. + data <- lapply(batches, function(batch) as.data.frame(as_tibble(batch))) + + # Read keys to map with each groupped batch. + keys <- readMultipleObjects(inputCon) + + list(keys = keys, data = data) + } else { + stop("'arrow' package should be installed.") + } +} + readRowList <- function(obj) { # readRowList is meant for use inside an lapply. As a result, it is # necessary to open a standalone connection for the row and consume diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index f751b952f3915..32592f92b325f 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -229,6 +229,29 @@ gapplyInternal <- function(x, func, schema) { if (is.character(schema)) { schema <- structType(schema) } + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] == "true" + if (arrowEnabled) { + requireNamespace1 <- requireNamespace + if (!requireNamespace1("arrow", quietly = TRUE)) { + stop("'arrow' package should be installed.") + } + # Currenty Arrow optimization does not support raw for now. + # Also, it does not support explicit float type set by users. + if (inherits(schema, "structType")) { + if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "FloatType"))) { + stop("Arrow optimization with gapply do not support FloatType yet.") + } + if (any(sapply(schema$fields(), function(x) x$dataType.toString() == "BinaryType"))) { + stop("Arrow optimization with gapply do not support BinaryType yet.") + } + } else if (is.null(schema)) { + stop(paste0("Arrow optimization does not support gapplyCollect yet. 
Please use ", + "'collect' and 'gapply' APIs instead.")) + } else { + stop("'schema' should be DDL-formatted string or structType.") + } + } + packageNamesArr <- serialize(.sparkREnv[[".packages"]], connection = NULL) broadcastArr <- lapply(ls(.broadcastNames), diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index c2adf613acb02..eed4c843f9a3b 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -49,7 +49,7 @@ compute <- function(mode, partition, serializer, deserializer, key, names(inputData) <- colNames } else { # Check to see if inputData is a valid data.frame - stopifnot(deserializer == "byte") + stopifnot(deserializer == "byte" || deserializer == "arrow") stopifnot(class(inputData) == "data.frame") } @@ -63,7 +63,7 @@ compute <- function(mode, partition, serializer, deserializer, key, output <- split(output, seq(nrow(output))) } else { # Serialize the output to a byte array - stopifnot(serializer == "byte") + stopifnot(serializer == "byte" || serializer == "arrow") } } else { output <- computeFunc(partition, inputData) @@ -171,6 +171,10 @@ if (isEmpty != 0) { data <- dataWithKeys$data } else if (deserializer == "row") { data <- SparkR:::readMultipleObjects(inputCon) + } else if (deserializer == "arrow" && mode == 2) { + dataWithKeys <- SparkR:::readDeserializeInArrow(inputCon) + keys <- dataWithKeys$keys + data <- dataWithKeys$data } # Timing reading input data for execution @@ -181,17 +185,40 @@ if (isEmpty != 0) { colNames, computeFunc, data) } else { # gapply mode + outputs <- list() for (i in 1:length(data)) { # Timing reading input data for execution inputElap <- elapsedSecs() output <- compute(mode, partition, serializer, deserializer, keys[[i]], colNames, computeFunc, data[[i]]) computeElap <- elapsedSecs() - outputResult(serializer, output, outputCon) + if (deserializer == "arrow") { + outputs[[length(outputs) + 1L]] <- output + } else { + outputResult(serializer, output, outputCon) + } outputElap <- elapsedSecs() computeInputElapsDiff <- computeInputElapsDiff + (computeElap - inputElap) outputComputeElapsDiff <- outputComputeElapsDiff + (outputElap - computeElap) } + + if (deserializer == "arrow") { + # This is a hack to avoid CRAN check. Arrow is not uploaded into CRAN now. See ARROW-3204. + requireNamespace1 <- requireNamespace + if (requireNamespace1("arrow", quietly = TRUE)) { + write_arrow <- get("write_arrow", envir = asNamespace("arrow"), inherits = FALSE) + # See https://stat.ethz.ch/pipermail/r-help/2010-September/252046.html + # rbind.fill might be an anternative to make it faster if plyr is installed. + combined <- do.call("rbind", outputs) + + # Likewise, there looks no way to send each batch in streaming format via socket + # connection. See ARROW-4512. + # So, it writes the whole Arrow streaming-formatted binary at once for now. 
+ SparkR:::writeRaw(outputCon, write_arrow(combined, raw())) + } else { + stop("'arrow' package should be installed.") + } + } } } else { output <- compute(mode, partition, serializer, deserializer, NULL, diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index a5dde2084da71..9dc699c09a1e4 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -3529,6 +3529,116 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { }) }) +test_that("gapply() Arrow optimization", { + skip_if_not_installed("arrow") + df <- createDataFrame(mtcars) + + conf <- callJMethod(sparkSession, "conf") + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false") + tryCatch({ + ret <- gapply(df, + "gear", + function(key, grouped) { + if (length(key) > 0) { + stopifnot(is.numeric(key[[1]])) + } + stopifnot(class(grouped) == "data.frame") + grouped + }, + schema(df)) + expected <- collect(ret) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true") + tryCatch({ + ret <- gapply(df, + "gear", + function(key, grouped) { + if (length(key) > 0) { + stopifnot(is.numeric(key[[1]])) + } + stopifnot(class(grouped) == "data.frame") + grouped + }, + schema(df)) + actual <- collect(ret) + expect_equal(actual, expected) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) +}) + +test_that("gapply() Arrow optimization - type specification", { + skip_if_not_installed("arrow") + # Note that regular gapply() seems not supporting date and timestamps + # whereas Arrow-optimized gapply() does. 
+ rdf <- data.frame(list(list(a = 1, + b = "a", + c = TRUE, + d = 1.1, + e = 1L))) + df <- createDataFrame(rdf) + + conf <- callJMethod(sparkSession, "conf") + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "false") + tryCatch({ + ret <- gapply(df, + "a", + function(key, grouped) { grouped }, schema(df)) + expected <- collect(ret) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) + + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true") + tryCatch({ + ret <- gapply(df, + "a", + function(key, grouped) { grouped }, schema(df)) + actual <- collect(ret) + expect_equal(actual, expected) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) +}) + +test_that("gapply() Arrow optimization - type specification (date and timestamp)", { + skip_if_not_installed("arrow") + rdf <- data.frame(list(list(a = as.Date("1990-02-24"), + b = as.POSIXct("1990-02-24 12:34:56")))) + df <- createDataFrame(rdf) + + conf <- callJMethod(sparkSession, "conf") + arrowEnabled <- sparkR.conf("spark.sql.execution.arrow.enabled")[[1]] + + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", "true") + tryCatch({ + ret <- gapply(df, + "a", + function(key, grouped) { grouped }, schema(df)) + expect_equal(collect(ret), rdf) + }, + finally = { + # Resetting the conf back to default value + callJMethod(conf, "set", "spark.sql.execution.arrow.enabled", arrowEnabled) + }) +}) + test_that("Window functions on a DataFrame", { df <- createDataFrame(list(list(1L, "1"), list(2L, "2"), list(1L, "1"), list(2L, "2")), schema = c("key", "value")) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index b367c7ffb1590..971d11f84173a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -45,7 +45,7 @@ private[spark] class RRunner[U]( colNames: Array[String] = null, mode: Int = RRunnerModes.RDD) extends Logging { - private var bootTime: Double = _ + protected var bootTime: Double = _ private var dataStream: DataInputStream = _ val readData = numPartitions match { case -1 => @@ -91,28 +91,70 @@ private[spark] class RRunner[U]( } try { - return new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj + newReaderIterator(dataStream, errThread) + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines(), e) + } + } + + protected def newReaderIterator( + dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = { + new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext()) { + _nextObj = read() } + obj + } - var _nextObj = read() + private var _nextObj = read() - def hasNext(): Boolean = { - val hasMore = (_nextObj != null) - if (!hasMore) { - dataStream.close() - } - hasMore + def hasNext(): Boolean = { + val hasMore = _nextObj != null + if (!hasMore) { + dataStream.close() } + hasMore + } + } + } + + protected def writeData( + dataOut: DataOutputStream, + printOut: PrintStream, + iter: Iterator[_]): Unit = { + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + 
dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, innerIter: Iterator[_]) => + for (innerElem <- innerIter) { + writeElem(innerElem) + } + // Writes key which can be used as a boundary in group-aggregate + dataOut.writeByte('r') + writeElem(key) + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) } - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines()) } } @@ -171,37 +213,7 @@ private[spark] class RRunner[U]( val printOut = new PrintStream(stream) - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println - } - } - - for (elem <- iter) { - elem match { - case (key, innerIter: Iterator[_]) => - for (innerElem <- innerIter) { - writeElem(innerElem) - } - // Writes key which can be used as a boundary in group-aggregate - dataOut.writeByte('r') - writeElem(key) - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } + writeData(dataOut, printOut, iter) stream.flush() } catch { @@ -261,7 +273,7 @@ private[spark] class RRunner[U]( } } - private def readByteArrayData(length: Int): Array[Byte] = { + protected def readByteArrayData(length: Int): Array[Byte] = { length match { case length if length > 0 => val obj = new Array[Byte](length) @@ -280,7 +292,7 @@ private[spark] class RRunner[U]( } } -private object SpecialLengths { +private[spark] object SpecialLengths { val TIMING_DATA = -1 } @@ -290,7 +302,7 @@ private[spark] object RRunnerModes { val DATAFRAME_GAPPLY = 2 } -private[r] class BufferedStreamThread( +private[spark] class BufferedStreamThread( in: InputStream, name: String, errBufferSize: Int) extends Thread(name) with Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index bfb70c2ef4c89..58bb1915b3c72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode } +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -437,19 +438,30 @@ object FlatMapGroupsInR { groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], child: LogicalPlan): LogicalPlan = { - val mapped = 
FlatMapGroupsInR( - func, - packageNames, - broadcastVars, - inputSchema, - schema, - UnresolvedDeserializer(keyDeserializer, groupingAttributes), - UnresolvedDeserializer(valueDeserializer, dataAttributes), - groupingAttributes, - dataAttributes, - CatalystSerde.generateObjAttr(RowEncoder(schema)), - child) - CatalystSerde.serialize(mapped)(RowEncoder(schema)) + if (SQLConf.get.arrowEnabled) { + FlatMapGroupsInRWithArrow( + func, + packageNames, + broadcastVars, + inputSchema, + schema.toAttributes, + UnresolvedDeserializer(keyDeserializer, groupingAttributes), + groupingAttributes, + child) + } else { + CatalystSerde.serialize(FlatMapGroupsInR( + func, + packageNames, + broadcastVars, + inputSchema, + schema, + UnresolvedDeserializer(keyDeserializer, groupingAttributes), + UnresolvedDeserializer(valueDeserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr(RowEncoder(schema)), + child))(RowEncoder(schema)) + } } } @@ -464,7 +476,7 @@ case class FlatMapGroupsInR( groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, - child: LogicalPlan) extends UnaryNode with ObjectProducer{ + child: LogicalPlan) extends UnaryNode with ObjectProducer { override lazy val schema = outputSchema @@ -473,6 +485,27 @@ case class FlatMapGroupsInR( child) } +/** + * Similar with `FlatMapGroupsInR` but serializes and deserializes input/output in + * Arrow format. + * This is also somewhat similar with [[FlatMapGroupsInPandas]]. + */ +case class FlatMapGroupsInRWithArrow( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + output: Seq[Attribute], + keyDeserializer: Expression, + groupingAttributes: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override protected def stringArgs: Iterator[Any] = Iterator( + inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child) + + override val producedAttributes = AttributeSet(output) +} + /** Factory for constructing new `CoGroup` nodes. 
*/ object CoGroup { def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index edfa70403ad15..6827ba6fe474c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -596,6 +596,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInRWithArrow(f, p, b, is, ot, key, grouping, child) => + execution.FlatMapGroupsInRWithArrowExec( + f, p, b, is, ot, key, grouping, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 03d1bbf2ab882..dd76efd22a10c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import scala.collection.JavaConverters._ import scala.language.existentials import org.apache.spark.api.java.function.MapFunction @@ -31,7 +32,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.r.ArrowRRunner import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.GroupStateTimeout import org.apache.spark.sql.types._ @@ -437,6 +440,71 @@ case class FlatMapGroupsInRExec( } } +/** + * Similar with [[FlatMapGroupsInRExec]] but serializes and deserializes input/output in + * Arrow format. + * This is also somewhat similar with + * [[org.apache.spark.sql.execution.python.FlatMapGroupsInPandasExec]]. 
+ */ +case class FlatMapGroupsInRWithArrowExec( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + inputSchema: StructType, + output: Seq[Attribute], + keyDeserializer: Expression, + groupingAttributes: Seq[Attribute], + child: SparkPlan) extends UnaryExecNode { + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val runner = new ArrowRRunner( + func, packageNames, broadcastVars, inputSchema, SQLConf.get.sessionLocalTimeZone) + + val groupedByRKey = grouped.map { case (key, rowIter) => + val newKey = rowToRBytes(getKey(key).asInstanceOf[Row]) + (newKey, rowIter) + } + + // The communication mechanism is as follows: + // + // JVM side R side + // + // 1. Group internal rows + // 2. Grouped internal rows --------> Arrow record batches + // 3. Grouped keys --------> Regular serialized keys + // 4. Converts each Arrow record batch to each R data frame + // 5. Deserializes keys + // 6. Maps each key to each R Data frame + // 7. Computes R native function on each key/R data frame + // 8. Converts all R data frames to Arrow record batches + // 9. Columnar batches <-------- Arrow record batches + // 10. Each row from each batch + // + // Note that, unlike Python vectorization implementation, R side sends Arrow formatted + // binary in a batch due to the limitation of R API. See also ARROW-4512. + val columnarBatchIter = runner.compute(groupedByRKey, -1) + val outputProject = UnsafeProjection.create(output, output) + columnarBatchIter.flatMap(_.rowIterator().asScala).map(outputProject) + } + } +} + /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala new file mode 100644 index 0000000000000..a8d0bf17c7a68 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.r + +import java.io._ +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.api.r._ +import org.apache.spark.api.r.SpecialLengths +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.util.Utils + + +/** + * Similar to `ArrowPythonRunner`, but exchange data with R worker via Arrow stream. + */ +class ArrowRRunner( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + schema: StructType, + timeZoneId: String) + extends RRunner[ColumnarBatch]( + func, + "arrow", + "arrow", + packageNames, + broadcastVars, + numPartitions = -1, + isDataFrame = true, + schema.fieldNames, + RRunnerModes.DATAFRAME_GAPPLY) { + + protected override def writeData( + dataOut: DataOutputStream, + printOut: PrintStream, + iter: Iterator[_]): Unit = if (iter.hasNext) { + val inputIterator = iter.asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]] + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "stdout writer for R", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val out = new ByteArrayOutputStream() + val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]] + + Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) + writer.start() + + while (inputIterator.hasNext) { + val (key, nextBatch) = inputIterator.next() + keys.append(key) + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + writer.end() + } { + // Don't close root and allocator in TaskCompletionListener to prevent + // a race condition. See `ArrowPythonRunner`. + root.close() + allocator.close() + } + + // Currently, there looks no way to read batch by batch by socket connection in R side, + // See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at + // once for now. 
+ val data = out.toByteArray + dataOut.writeInt(data.length) + dataOut.write(data) + + keys.foreach(dataOut.write) + } + + protected override def newReaderIterator( + dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[ColumnarBatch] = { + new Iterator[ColumnarBatch] { + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "stdin reader for R", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var vectors: Array[ColumnVector] = _ + + TaskContext.get().addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + private var nextObj: ColumnarBatch = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): ColumnarBatch = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[ColumnarBatch] + obj + } else { + Iterator.empty.next() + } + } + + private def read(): ColumnarBatch = try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + eos = true + null + } + } else { + dataStream.readInt() match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length > 0 => + // Likewise, there looks no way to send each batch in streaming format via socket + // connection. See ARROW-4512. + // So, it reads the whole Arrow streaming-formatted binary at once for now. + val in = new ByteArrayReadableSeekableByteChannel(readByteArrayData(length)) + reader = new ArrowStreamReader(in, allocator) + root = reader.getVectorSchemaRoot + vectors = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case length if length == 0 => + eos = true + null + } + } + } catch { + case eof: EOFException => + throw new SparkException( + "R worker exited unexpectedly (crashed)\n " + errThread.getLines(), eof) + } + } + } +}
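For completeness, a hedged end-to-end sketch of exercising the new path from a user session, mirroring the "type specification (date and timestamp)" test added above. Enabling the conf through `sparkConfig` at session start is an assumption on my part (the tests flip it through the internal `callJMethod` API, and the description uses `--conf` on the shell); the identity function per group is the same as in the tests.

```r
library(SparkR)

# Assumed way to enable the optimization for the whole session; the PR description
# does the equivalent with --conf on the sparkR shell.
sparkR.session(master = "local[*]",
               sparkConfig = list(spark.sql.execution.arrow.enabled = "true"))

# A frame with date and timestamp columns, the case covered by the new test.
rdf <- data.frame(gear = c(3, 3, 4),
                  day  = as.Date("1990-02-24") + 0:2,
                  when = as.POSIXct("1990-02-24 12:34:56") + 0:2)
df <- createDataFrame(rdf)

# Identity function per group, keyed by gear, exactly as in the added tests.
collect(gapply(df,
               "gear",
               function(key, grouped) { grouped },
               schema(df)))

sparkR.session.stop()
```

With the conf left at its default (`false`), the same call falls back to the existing non-Arrow `gapply()` path, which, per the comment in the added test, does not appear to handle date and timestamp columns.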