diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 29ca751519abd..3721c6c186815 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,12 +23,9 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.existentials -import net.razorvine.pickle.{Pickler, Unpickler} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} @@ -738,104 +735,6 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - SerDeUtil.initialize() - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. - */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - SerDeUtil.initialize() - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index ebdc3533e0992..7355c7550c632 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging { } initialize() + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. + */ + private[python] def toJavaArray(jrdd: JavaRDD[_]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { val pickle = new Pickler val kt = Try { @@ -148,7 +220,7 @@ private[spark] object SerDeUtil extends Logging { */ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } pyRDD.mapPartitions { iter => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b478c21537c2a..f3338c774f6ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -685,7 +685,7 @@ private[spark] object SerDe extends Serializable { def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { jRDD.rdd.mapPartitions { iter => initialize() // let it called in executor - new PythonRDD.AutoBatchedPickler(iter) + new SerDeUtil.AutoBatchedPickler(iter) } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 593542e5984c8..222191a64f9e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1938,7 +1938,7 @@ def _to_java_object_rdd(self): RDD is serialized in batch or not. """ rdd = self._pickled() - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, True) + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c90b49a953cee..79916db6ccdaf 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -938,7 +938,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -1124,8 +1123,7 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1381,26 +1379,26 @@ class LocalHiveContext(HiveContext): An in-process metadata data is created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. - >>> import os - >>> hiveCtx = LocalHiveContext(sc) - >>> try: - ... supress = hiveCtx.sql("DROP TABLE src") - ... except Exception: - ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], - ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.sql( - ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" - ... % kv1) - >>> results = hiveCtx.sql("FROM src SELECT value" - ... ).map(lambda r: int(r.value.split('_')[1])) - >>> num = results.count() - >>> reduce_sum = results.reduce(lambda x, y: x + y) - >>> num - 500 - >>> reduce_sum - 130091 + # >>> import os + # >>> hiveCtx = LocalHiveContext(sc) + # >>> try: + # ... supress = hiveCtx.sql("DROP TABLE src") + # ... except Exception: + # ... pass + # >>> kv1 = os.path.join(os.environ["SPARK_HOME"], + # ... 'examples/src/main/resources/kv1.txt') + # >>> supress = hiveCtx.sql( + # ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + # >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" + # ... % kv1) + # >>> results = hiveCtx.sql("FROM src SELECT value" + # ... ).map(lambda r: int(r.value.split('_')[1])) + # >>> num = results.count() + # >>> reduce_sum = results.reduce(lambda x, y: x + y) + # >>> num + # 500 + # >>> reduce_sum + # 130091 """ def __init__(self, sparkContext, sqlContext=None): @@ -1771,7 +1769,6 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 948122d42f0e1..539ef1134247e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.storage.StorageLevel import scala.collection.JavaConversions._ @@ -414,12 +415,8 @@ class SchemaRDD( */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { val rowSchema = StructType.fromAttributes(this.queryExecution.analyzed.output) - this.mapPartitions { iter => - val pickle = new Pickler - iter.map { row => - rowToJArray(row, rowSchema) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)) - } + val jrdd = this.map(rowToJArray(_, rowSchema)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) } /**