diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 6f5d3dda377de..ff5f7a0e0d3fc 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -372,10 +372,8 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
-    // The right way to implement this would be to use TypeTags to get the full
-    // type of T. Since I don't want to introduce breaking changes throughout the
-    // entire Spark API, I have to use this hacky approach:
-    def write(bytes: Array[Byte]) {
+
+    def writeBytes(bytes: Array[Byte]) {
       if (bytes == null) {
         dataOut.writeInt(SpecialLengths.NULL)
       } else {
@@ -384,7 +382,7 @@
       }
     }
 
-    def writeS(str: String) {
+    def writeString(str: String) {
       if (str == null) {
         dataOut.writeInt(SpecialLengths.NULL)
       } else {
@@ -392,42 +390,45 @@
       }
     }
 
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
     if (iter.hasNext) {
       val first = iter.next()
       val newIter = Seq(first).iterator ++ iter
       first match {
         case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(write)
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes)
         case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach(writeS)
+          newIter.asInstanceOf[Iterator[String]].foreach(writeString)
         case stream: PortableDataStream =>
           newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            write(stream.toArray())
+            writeBytes(stream.toArray())
           }
         case (key: String, stream: PortableDataStream) =>
           newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
             case (key, stream) =>
-              writeS(key)
-              write(stream.toArray())
+              writeString(key)
+              writeBytes(stream.toArray())
           }
         case (key: String, value: String) =>
           newIter.asInstanceOf[Iterator[(String, String)]].foreach {
             case (key, value) =>
-              writeS(key)
-              writeS(value)
+              writeString(key)
+              writeString(value)
           }
         case (key: Array[Byte], value: Array[Byte]) =>
           newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
             case (key, value) =>
-              write(key)
-              write(value)
+              writeBytes(key)
+              writeBytes(value)
           }
         // key is null
-        case (null, v:Array[Byte]) =>
+        case (null, value: Array[Byte]) =>
           newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
             case (key, value) =>
-              write(key)
-              write(value)
+              writeBytes(key)
+              writeBytes(value)
           }
         case other =>
           throw new SparkException("Unexpected element type " + first.getClass)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index df95ce9622573..4c930d45ee251 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -487,7 +487,7 @@ def loads(self, stream):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
-        if length == SpecialLengths.NULL:
+        elif length == SpecialLengths.NULL:
             return None
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
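
For readers tracing the wire format these helpers speak, here is a minimal, self-contained sketch of the length-prefixed framing that `writeBytes`/`writeString` emit and that pyspark's `loads` consumes: each record is a 4-byte big-endian length followed by the payload, with negative sentinel lengths marking special cases. The sentinel values (`-1`, `-5`) and the object/helper names are assumptions standing in for `SpecialLengths.END_OF_DATA_SECTION` and `SpecialLengths.NULL`; treat this as a model of the protocol, not the Spark implementation.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.nio.charset.StandardCharsets

// Hypothetical standalone model of the record framing; the sentinel values
// are assumed to mirror SpecialLengths on both the Scala and Python sides.
object FramingSketch {
  val END_OF_DATA_SECTION = -1 // assumption: SpecialLengths.END_OF_DATA_SECTION
  val NULL_LENGTH = -5         // assumption: SpecialLengths.NULL

  // Mirrors the shape of the patched writeBytes: null becomes a sentinel,
  // otherwise a length prefix followed by the raw payload.
  def writeBytes(bytes: Array[Byte], dataOut: DataOutputStream): Unit = {
    if (bytes == null) {
      dataOut.writeInt(NULL_LENGTH)
    } else {
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
    }
  }

  // Mirrors the shape of the patched writeString: strings travel as UTF-8 bytes.
  def writeString(str: String, dataOut: DataOutputStream): Unit = {
    if (str == null) writeBytes(null, dataOut)
    else writeBytes(str.getBytes(StandardCharsets.UTF_8), dataOut)
  }

  // Mirrors the Python loads(): the length prefix is either a sentinel or
  // the byte count of the UTF-8 payload that follows.
  def readString(dataIn: DataInputStream): String = {
    val length = dataIn.readInt()
    if (length == END_OF_DATA_SECTION) {
      throw new EOFException()
    } else if (length == NULL_LENGTH) {
      null
    } else {
      val payload = new Array[Byte](length)
      dataIn.readFully(payload)
      new String(payload, StandardCharsets.UTF_8)
    }
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    writeString("hello", out)
    writeString(null, out)
    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    println(readString(in)) // hello
    println(readString(in)) // null
  }
}
```

On the serializers.py hunk: because the `END_OF_DATA_SECTION` branch raises, the `if` to `elif` change is behavior-preserving; it just makes the three-way dispatch on the length prefix explicit, as in `readString` above.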
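The comment the patch relocates points at the underlying constraint: `T` is erased at runtime, so `writeIteratorToStream` cannot know the element type statically and instead peeks at the first element, then casts the whole iterator. A stripped-down sketch of that dispatch pattern follows, with hypothetical handler bodies standing in for the real write helpers:

```scala
// Hypothetical sketch of the erasure workaround: the first element's runtime
// class decides how the whole iterator is cast.
object DispatchSketch {
  def describeAll[T](iter: Iterator[T]): Unit = {
    if (iter.hasNext) {
      val first = iter.next()
      // Re-attach the element we consumed while peeking.
      val newIter = Seq(first).iterator ++ iter
      first match {
        case _: Array[Byte] =>
          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(b => println(s"bytes of length ${b.length}"))
        case _: String =>
          newIter.asInstanceOf[Iterator[String]].foreach(s => println(s"string: $s"))
        case other =>
          throw new IllegalArgumentException("Unexpected element type " + other.getClass)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    describeAll(Iterator("a", "b"))             // string: a / string: b
    describeAll(Iterator(Array[Byte](1, 2, 3))) // bytes of length 3
  }
}
```

This also shows why the diff carries a dedicated `(null, value: Array[Byte])` case: a tuple whose key is null matches no typed pattern like `(key: Array[Byte], ...)`, so it needs its own case ahead of the `case other` fallback.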