Fix code that uses ByteBuffer.array incorrectly
`ByteBuffer` does not guarantee that all contents of `ByteBuffer.array` are valid. For example, a `ByteBuffer` returned by `ByteBuffer.slice` shares a backing array that can contain bytes outside the slice's position/limit range. We should not use the whole backing array of a `ByteBuffer` unless we know that is correct.

This patch fixes all places that used `ByteBuffer.array` incorrectly.
zsxwing committed Dec 2, 2015
1 parent 2cef1cd commit 93b68de
Showing 21 changed files with 73 additions and 60 deletions.
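
The pitfall described in the commit message is easiest to see with a sliced buffer: `array()` exposes the entire shared backing array and ignores `arrayOffset()`, `position()`, and `limit()`. Below is a minimal, self-contained Scala sketch of the bug and of the copy-only-the-remaining-bytes approach the patch moves to (illustrative only, not code from this commit):

import java.nio.ByteBuffer

object ByteBufferArrayPitfall {
  // Copy only the valid region [position, limit) into a fresh array.
  // This sketches the idea behind the JavaUtils.bufferToArray calls introduced
  // below; it is not the actual Spark implementation.
  def toArray(buf: ByteBuffer): Array[Byte] = {
    val dst = new Array[Byte](buf.remaining())
    buf.duplicate().get(dst) // duplicate() keeps the caller's position untouched
    dst
  }

  def main(args: Array[String]): Unit = {
    val backing = ByteBuffer.wrap("0123456789".getBytes("UTF-8"))
    backing.position(3)
    val slice = backing.slice() // logical content is "3456789"

    // Wrong: the whole backing array comes back, including bytes before the slice.
    println(new String(slice.array(), "UTF-8"))  // prints 0123456789

    // Right: copy only the bytes between position and limit.
    println(new String(toArray(slice), "UTF-8")) // prints 3456789
  }
}
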
@@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage

// StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
// using our binary protocol.
val levelBytes = serializer.newInstance().serialize(level).array()
val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))

// Convert or copy nio buffer into array in order to serialize it.
val nioBuffer = blockData.nioByteBuffer()
val array = if (nioBuffer.hasArray) {
nioBuffer.array()
} else {
val data = new Array[Byte](nioBuffer.remaining())
nioBuffer.get(data)
data
}
val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())

client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
new RpcResponseCallback {
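
The hand-rolled hasArray/copy block removed above is exactly the logic that a centralized bufferToArray-style helper replaces. A hedged sketch of such a helper follows (an assumption about its shape; the real `JavaUtils.bufferToArray` may differ in detail). The fast path matters for serializer output, which typically wraps a freshly allocated array, while the copy path protects against sliced, read-only, or direct buffers.

import java.nio.ByteBuffer

object BufferToArraySketch {
  // Assumption: JavaUtils.bufferToArray behaves along these lines.
  def bufferToArray(buffer: ByteBuffer): Array[Byte] = {
    if (buffer.hasArray && buffer.arrayOffset() == 0 &&
        buffer.array().length == buffer.remaining()) {
      // Fast path: the backing array is exactly the valid content, so reuse it.
      buffer.array()
    } else {
      // Otherwise copy only the bytes between position and limit.
      val bytes = new Array[Byte](buffer.remaining())
      buffer.get(bytes)
      bytes
    }
  }
}
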
@@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcTimeout
@@ -997,9 +998,10 @@ class DAGScheduler(
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
JavaUtils.bufferToArray(
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
}

taskBinary = sc.broadcast(taskBinaryBytes)
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -191,8 +191,8 @@ private[spark] object Task {

// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task).array()
out.write(taskBytes)
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
ByteBuffer.wrap(out.toByteArray)
}

@@ -17,7 +17,7 @@

package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

import scala.collection.mutable
@@ -31,6 +31,7 @@ import org.apache.commons.io.IOUtils

import org.apache.spark.{SparkException, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.ByteBufferInputStream

/**
* Custom serializer used for generic Avro records. If the user registers the schemas
@@ -81,7 +82,7 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
* seen values so to limit the number of times that decompression has to be done.
*/
def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
val bis = new ByteArrayInputStream(schemaBytes.array())
val bis = new ByteBufferInputStream(schemaBytes)
val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
new Schema.Parser().parse(new String(bytes, "UTF-8"))
})
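
The change above works because an InputStream view of the `ByteBuffer` only ever reads the bytes between position and limit, whereas `ByteArrayInputStream(schemaBytes.array())` would read the whole backing array. A minimal sketch of such a stream (illustrative; Spark's `org.apache.spark.util.ByteBufferInputStream` is more involved):

import java.io.InputStream
import java.nio.ByteBuffer

class MinimalByteBufferInputStream(buffer: ByteBuffer) extends InputStream {
  override def read(): Int =
    if (!buffer.hasRemaining) -1 else buffer.get() & 0xFF

  override def read(dest: Array[Byte], offset: Int, length: Int): Int =
    if (!buffer.hasRemaining) {
      -1
    } else {
      // Never reads past limit(), so bytes outside the valid region are never seen.
      val amountToGet = math.min(buffer.remaining(), length)
      buffer.get(dest, offset, amountToGet)
      amountToGet
    }
}
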
@@ -35,7 +35,7 @@ import org.roaringbitmap.RoaringBitmap
import org.apache.spark._
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.network.util.{ByteUnit, JavaUtils}
import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.collection.CompactBuffer
@@ -307,7 +307,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val kryo = borrowKryo()
try {
input.setBuffer(bytes.array)
input.setBuffer(JavaUtils.bufferToArray(bytes))
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
releaseKryo(kryo)
@@ -319,7 +319,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
val oldClassLoader = kryo.getClassLoader
try {
kryo.setClassLoader(loader)
input.setBuffer(bytes.array)
input.setBuffer(JavaUtils.bufferToArray(bytes))
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
kryo.setClassLoader(oldClassLoader)
@@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
val file = getFile(blockId)
val os = file.getOutStream(WriteType.TRY_CACHE)
try {
os.write(bytes.array())
Utils.writeByteBuffer(bytes, os)
} catch {
case NonFatal(e) =>
logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)
15 changes: 14 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -178,7 +178,20 @@ private[spark] object Utils extends Logging {
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
val bbval = new Array[Byte](bb.remaining())
bb.get(bbval)
out.write(bbval)
}
}

/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
*/
def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
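
A quick REPL-style usage sketch of the pattern above (not from the patch): the offset into the backing array must be `arrayOffset() + position()`, not `position()` alone, because a sliced or duplicated buffer can start partway into a shared array.

import java.io.ByteArrayOutputStream
import java.nio.ByteBuffer

val backing = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4, 5, 6, 7, 8))
backing.position(5)
val buf = backing.slice() // valid content: 6, 7, 8; arrayOffset() == 5

val out = new ByteArrayOutputStream()
if (buf.hasArray) {
  // Indexing from array() has to account for both arrayOffset() and position().
  out.write(buf.array(), buf.arrayOffset() + buf.position(), buf.remaining())
} else {
  val tmp = new Array[Byte](buf.remaining())
  buf.get(tmp)
  out.write(tmp)
}
assert(out.toByteArray.sameElements(Array[Byte](6, 7, 8)))
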
@@ -41,6 +41,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
@@ -430,7 +431,7 @@ public void randomizedStressTest() {
}

for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
final byte[] key = entry.getKey().array();
final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
}
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
final byte[] key = entry.getKey().array();
final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -23,6 +23,7 @@ import org.mockito.Matchers.any
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import org.apache.spark.metrics.source.JvmSource
@@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
intercept[RuntimeException] {
@@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable {

def unpackBytes(obj: Any): Array[Byte] = {
val bytes: Array[Byte] = obj match {
case buf: java.nio.ByteBuffer => buf.array()
case buf: java.nio.ByteBuffer =>
val arr = new Array[Byte](buf.remaining())
buf.get(arr)
arr
case arr: Array[Byte] => arr
case other => throw new SparkException(
s"Unknown BYTES type ${other.getClass.getName}")
@@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable {

/* Serialize to bytes. */
def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
val body = event.getBody.array()
out.writeInt(body.length)
out.write(body)
val body = event.getBody
out.writeInt(body.remaining())
Utils.writeByteBuffer(body, out)

val numHeaders = event.getHeaders.size()
out.writeInt(numHeaders)
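
The write side above is length-prefixed with `remaining()`, so a matching read side only needs to read back exactly that many bytes. A hedged round-trip sketch with hypothetical helper names (not the actual SparkFlumeEvent code):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.nio.ByteBuffer

object LengthPrefixedBodySketch {
  // Hypothetical helpers mirroring the length-prefixed write shown in the hunk above.
  def writeBody(out: ObjectOutputStream, body: ByteBuffer): Unit = {
    out.writeInt(body.remaining())      // the length prefix covers only the valid bytes
    val bytes = new Array[Byte](body.remaining())
    body.duplicate().get(bytes)         // copy [position, limit) without moving the caller's position
    out.write(bytes)
  }

  def readBody(in: ObjectInputStream): Array[Byte] = {
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    bytes
  }

  def main(args: Array[String]): Unit = {
    val baos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(baos)
    writeBody(oos, ByteBuffer.wrap("event body".getBytes("UTF-8")))
    oos.flush()
    val ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray))
    assert(new String(readBody(ois), "UTF-8") == "event body")
  }
}
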
@@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
@@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log
val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map {
case (key, value) => (key.toString, value.toString)
}).map(_.asJava)
val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8))
val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody))
utils.assertOutput(headers.asJava, bodies.asJava)
}
} finally {
@@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
@@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

@@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w
event =>
event.getHeaders.get("test") should be("header")
}
val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8))
val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody))
output should be (input)
}
} finally {
@@ -29,6 +29,7 @@ import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
@@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun

testIfEnabled("custom message handling") {
val awsCredentials = KinesisTestUtils.getAWSCredentials()
def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5
val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
@@ -334,10 +334,11 @@ private void decodeBinaryBatch(int col, int num) throws IOException {
ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer();
int len = bytes.limit() - bytes.position();
if (originalTypes[col] == OriginalType.UTF8) {
UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len);
UTF8String str =
UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len);
rowWriters[n].write(col, str);
} else {
rowWriters[n].write(col, bytes.array(), bytes.position(), len);
rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len);
}
rows[n].setNotNullAt(col);
} else {
@@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
@@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer {

def serialize[T: ClassTag](o: T): Array[Byte] =
acquireRelease { k =>
k.serialize(o).array()
JavaUtils.bufferToArray(k.serialize(o))
}

def deserialize[T: ClassTag](bytes: Array[Byte]): T =
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation(
.flatMap(_.values))

batchStats += stats
CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats)
CachedBatch(rowCount, columnBuilders.map { builder =>
JavaUtils.bufferToArray(builder.build())
}, stats)
}

def hasNext: Boolean = rowIterator.hasNext
@@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter(
// are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying
// it.
val buffer = value.toByteBuffer
val offset = buffer.position()
val numBytes = buffer.limit() - buffer.position()
val offset = buffer.arrayOffset() + buffer.position()
val numBytes = buffer.remaining()
updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes))
}
}
@@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter {
// copying it.
val buffer = binary.toByteBuffer
val bytes = buffer.array()
val start = buffer.position()
val end = buffer.limit()
val start = buffer.arrayOffset() + buffer.position()
val end = buffer.arrayOffset() + buffer.limit()

var unscaled = 0L
var i = start
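
Both hunks above make the same correction: when reading straight out of `buffer.array()` to avoid a copy, every index has to be shifted by `arrayOffset()`. A small REPL-style sketch of the corrected index arithmetic with hypothetical values (not Parquet code):

import java.nio.ByteBuffer

val shared = ByteBuffer.wrap(Array.tabulate[Byte](10)(_.toByte))
shared.position(4)
val binary = shared.slice() // valid content: bytes 4..9 of the shared array
binary.limit(3)             // pretend only the first 3 of those belong to this field

// Zero-copy read: indices into array() must include arrayOffset().
val start = binary.arrayOffset() + binary.position() // 4, not 0
val end   = binary.arrayOffset() + binary.limit()    // 7, not 3
val fieldBytes = binary.array().slice(start, end)
assert(fieldBytes.sameElements(Array[Byte](4, 5, 6)))
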
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils}
import org.apache.spark.util.{Clock, Utils}
@@ -212,7 +213,7 @@ private[streaming] class ReceivedBlockTracker(
writeAheadLog.readAll().asScala.foreach { byteBuffer =>
logTrace("Recovering record " + byteBuffer)
Utils.deserialize[ReceivedBlockTrackerLogEvent](
byteBuffer.array, Thread.currentThread().getContextClassLoader) match {
JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match {
case BlockAdditionEvent(receivedBlockInfo) =>
insertAddedBlock(receivedBlockInfo)
case BatchAllocationEvent(time, allocatedBlocks) =>
@@ -24,6 +24,8 @@ import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataOutputStream

import org.apache.spark.util.Utils

/**
* A writer for writing byte-buffers to a write ahead log file.
*/
@@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf:
val lengthToWrite = data.remaining()
val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
stream.writeInt(lengthToWrite)
if (data.hasArray) {
stream.write(data.array())
} else {
// If the buffer is not backed by an array, we transfer using temp array
// Note that despite the extra array copy, this should be faster than byte-by-byte copy
while (data.hasRemaining) {
val array = new Array[Byte](data.remaining)
data.get(array)
stream.write(array)
}
}
Utils.writeByteBuffer(data, stream: OutputStream)
flush()
nextOffset = stream.getPos()
segment
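
One detail in the hunk above: judging by the `FSDataOutputStream` import and the `getPos()` call visible here, `stream` is both a `DataOutput` and an `OutputStream`, so a bare `Utils.writeByteBuffer(data, stream)` would be ambiguous between the two overloads shown earlier in this diff; the `stream: OutputStream` ascription picks one. A minimal sketch of that disambiguation pattern, with hypothetical overloads mirroring the ones in `Utils`:

import java.io.{ByteArrayOutputStream, DataOutput, DataOutputStream, OutputStream}

object OverloadAscriptionSketch {
  def write(out: DataOutput): String = "DataOutput overload"
  def write(out: OutputStream): String = "OutputStream overload"

  def main(args: Array[String]): Unit = {
    // DataOutputStream is both a DataOutput and an OutputStream, so write(stream)
    // alone would not compile; the ascription selects the OutputStream overload.
    val stream = new DataOutputStream(new ByteArrayOutputStream())
    println(write(stream: OutputStream))
  }
}
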