[SPARK-12084][Core] Fix code that uses ByteBuffer.array incorrectly #10083

Status: Closed · wants to merge 4 commits · showing changes from 3 commits
@@ -30,6 +30,7 @@ import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
import org.apache.spark.network.server._
import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -123,17 +124,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage

// StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
// using our binary protocol.
val levelBytes = serializer.newInstance().serialize(level).array()
val levelBytes = JavaUtils.bufferToArray(serializer.newInstance().serialize(level))

// Convert or copy nio buffer into array in order to serialize it.
val nioBuffer = blockData.nioByteBuffer()
val array = if (nioBuffer.hasArray) {
nioBuffer.array()
} else {
val data = new Array[Byte](nioBuffer.remaining())
nioBuffer.get(data)
data
}
val array = JavaUtils.bufferToArray(blockData.nioByteBuffer())

client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
new RpcResponseCallback {
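For context, JavaUtils.bufferToArray replaces the copy-or-reuse logic that this hunk deletes. A minimal Scala sketch of that logic follows; the real helper is Java code in org.apache.spark.network.util.JavaUtils and is not reproduced in this diff, so treat this as an approximation.

import java.nio.ByteBuffer

// Sketch: reuse the backing array when it exactly covers the buffer's contents,
// otherwise copy the remaining bytes, honoring position() and arrayOffset().
def bufferToArraySketch(buffer: ByteBuffer): Array[Byte] = {
  if (buffer.hasArray() && buffer.arrayOffset() == 0 &&
      buffer.array().length == buffer.remaining()) {
    buffer.array()
  } else {
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    bytes
  }
}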
@@ -34,6 +34,7 @@ import org.apache.commons.lang3.SerializationUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.RpcTimeout
@@ -997,9 +998,10 @@ class DAGScheduler(
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] = stage match {
case stage: ShuffleMapStage =>
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef).array()
JavaUtils.bufferToArray(
closureSerializer.serialize((stage.rdd, stage.shuffleDep): AnyRef))
case stage: ResultStage =>
closureSerializer.serialize((stage.rdd, stage.func): AnyRef).array()
JavaUtils.bufferToArray(closureSerializer.serialize((stage.rdd, stage.func): AnyRef))
}

taskBinary = sc.broadcast(taskBinaryBytes)
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -191,8 +191,8 @@ private[spark] object Task {

// Write the task itself and finish
dataOut.flush()
val taskBytes = serializer.serialize(task).array()
out.write(taskBytes)
val taskBytes = serializer.serialize(task)
Utils.writeByteBuffer(taskBytes, out)
ByteBuffer.wrap(out.toByteArray)
}

@@ -81,7 +81,10 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
* seen values so to limit the number of times that decompression has to be done.
*/
def decompress(schemaBytes: ByteBuffer): Schema = decompressCache.getOrElseUpdate(schemaBytes, {
val bis = new ByteArrayInputStream(schemaBytes.array())
val bis = new ByteArrayInputStream(
  schemaBytes.array(),
  schemaBytes.arrayOffset() + schemaBytes.position(),
  schemaBytes.remaining())

Member Author: Because decompressCache uses the ByteBuffer as a key, this code must not change the position of schemaBytes, so ByteBufferInputStream cannot be used here.
val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
new Schema.Parser().parse(new String(bytes, "UTF-8"))
})
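To illustrate the hazard the author describes, here is a standalone sketch that is not part of the patch: ByteBuffer.hashCode and equals are computed over the bytes between position and limit, so consuming the buffer while decompressing would silently change the cache key.

import java.nio.ByteBuffer

val key = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))
val hashBefore = key.hashCode()

// Reading through the buffer (for example via a stream over it) advances position.
key.get(new Array[Byte](key.remaining()))

// The exhausted buffer now hashes differently, so a map keyed on it can no longer find this entry.
assert(key.hashCode() != hashBefore)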
@@ -307,7 +307,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val kryo = borrowKryo()
try {
input.setBuffer(bytes.array)
input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
Contributor: Not necessary for this change, but at some point it might be worth switching this to Kryo's ByteBufferInput.

Member Author: Kryo will use the array as an internal buffer. Why is it not necessary?

Contributor: I'm saying that the change I proposed is not necessary, not that your change is not necessary.

Member Author: Got it. Sorry for my misunderstanding.
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
releaseKryo(kryo)
@@ -319,7 +319,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
val oldClassLoader = kryo.getClassLoader
try {
kryo.setClassLoader(loader)
input.setBuffer(bytes.array)
input.setBuffer(bytes.array(), bytes.arrayOffset() + bytes.position(), bytes.remaining())
kryo.readClassAndObject(input).asInstanceOf[T]
} finally {
kryo.setClassLoader(oldClassLoader)
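Why the extra offset and length arguments matter here, shown as an illustrative sketch independent of the Kryo API: for a heap buffer created by wrap(array, offset, length) or by slice(), array() returns the whole backing array, so starting at index 0 would read the wrong bytes.

import java.nio.ByteBuffer

val backing = new Array[Byte](16)
val buf = ByteBuffer.wrap(backing, 4, 8) // position = 4, limit = 12, arrayOffset = 0

// array() ignores position and arrayOffset: it is the full 16-byte backing array.
assert(buf.array() eq backing)

// The bytes this buffer actually represents start at arrayOffset() + position().
assert(buf.arrayOffset() + buf.position() == 4 && buf.remaining() == 8)

// slice() shifts the offset into arrayOffset() instead of position().
val sliced = buf.slice()
assert(sliced.arrayOffset() + sliced.position() == 4 && sliced.remaining() == 8)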
@@ -103,7 +103,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
val file = getFile(blockId)
val os = file.getOutStream(WriteType.TRY_CACHE)
try {
os.write(bytes.array())
Utils.writeByteBuffer(bytes, os)
} catch {
case NonFatal(e) =>
logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)
15 changes: 14 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -178,7 +178,20 @@ private[spark] object Utils extends Logging {
/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]]
*/
def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput): Unit = {
def writeByteBuffer(bb: ByteBuffer, out: DataOutput): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
val bbval = new Array[Byte](bb.remaining())
bb.get(bbval)
out.write(bbval)
}
}

/**
* Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.OutputStream]]
*/
def writeByteBuffer(bb: ByteBuffer, out: OutputStream): Unit = {
if (bb.hasArray) {
out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining())
} else {
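A minimal usage sketch for the two overloads added here. The streams are illustrative, and since Utils is private[spark] the sketch assumes compilation inside the Spark tree; note that ObjectOutput extends DataOutput, which is what lets the SparkFlumeEvent change further down call the first overload.

import java.io.{ByteArrayOutputStream, DataOutput, DataOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.apache.spark.util.Utils

val payload = ByteBuffer.wrap("payload".getBytes("UTF-8"))

// DataOutput overload; the ascription is needed because DataOutputStream is also an OutputStream.
val dataOut = new DataOutputStream(new ByteArrayOutputStream())
Utils.writeByteBuffer(payload.duplicate(), dataOut: DataOutput)

// OutputStream overload; duplicate() keeps the original buffer's position untouched.
val rawOut: OutputStream = new ByteArrayOutputStream()
Utils.writeByteBuffer(payload.duplicate(), rawOut)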
@@ -41,6 +41,7 @@
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.TestMemoryManager;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.Platform;
@@ -430,7 +431,7 @@ public void randomizedStressTest() {
}

for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
final byte[] key = entry.getKey().array();
final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -480,7 +481,7 @@ public void randomizedTestWithRecordsLargerThanPageSize() {
}
}
for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
final byte[] key = entry.getKey().array();
final byte[] key = JavaUtils.bufferToArray(entry.getKey());
final byte[] value = entry.getValue();
final BytesToBytesMap.Location loc =
map.lookup(key, Platform.BYTE_ARRAY_OFFSET, key.length);
@@ -23,6 +23,7 @@ import org.mockito.Matchers.any
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
import org.apache.spark.metrics.source.JvmSource
@@ -57,7 +58,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
}
val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val taskBinary = sc.broadcast(closureSerializer.serialize((rdd, func)).array)
val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func))))
val task = new ResultTask[String, String](
0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, Seq.empty)
intercept[RuntimeException] {
@@ -79,7 +79,10 @@ object AvroConversionUtil extends Serializable {

def unpackBytes(obj: Any): Array[Byte] = {
val bytes: Array[Byte] = obj match {
case buf: java.nio.ByteBuffer => buf.array()
case buf: java.nio.ByteBuffer =>
  val arr = new Array[Byte](buf.remaining())
  buf.get(arr)
  arr

Member: Can't you use bufferToArray here too?

Member Author: This is in the examples module, so I don't want to use a private API.
case arr: Array[Byte] => arr
case other => throw new SparkException(
s"Unknown BYTES type ${other.getClass.getName}")
@@ -93,9 +93,9 @@ class SparkFlumeEvent() extends Externalizable {

/* Serialize to bytes. */
def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
val body = event.getBody.array()
out.writeInt(body.length)
out.write(body)
val body = event.getBody
out.writeInt(body.remaining())
Utils.writeByteBuffer(body, out)

val numHeaders = event.getHeaders.size()
out.writeInt(numHeaders)
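For context, the length-prefixed layout written here implies a read side along these lines. This is a sketch of the presumed counterpart in readExternal; the actual code is not part of this diff.

import java.io.ObjectInput

// Sketch: read the length prefix, then the body bytes written by writeExternal above.
def readBody(in: ObjectInput): Array[Byte] = {
  val length = in.readInt()
  val body = new Array[Byte](length)
  in.readFully(body)
  body
}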
@@ -24,11 +24,11 @@ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets.UTF_8
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
@@ -119,7 +119,7 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log
val headers = flattenOutputBuffer.map(_.event.getHeaders.asScala.map {
case (key, value) => (key.toString, value.toString)
}).map(_.asJava)
val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8))
val bodies = flattenOutputBuffer.map(e => JavaUtils.bytesToString(e.event.getBody))
utils.assertOutput(headers.asJava, bodies.asJava)
}
} finally {
@@ -22,7 +22,6 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
@@ -31,6 +30,7 @@ import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

@@ -63,7 +63,7 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w
event =>
event.getHeaders.get("test") should be("header")
}
val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8))
val output = outputEvents.map(event => JavaUtils.bytesToString(event.getBody))
output should be (input)
}
} finally {
@@ -29,6 +29,7 @@ import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.{StorageLevel, StreamBlockId}
import org.apache.spark.streaming._
@@ -196,7 +197,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun

testIfEnabled("custom message handling") {
val awsCredentials = KinesisTestUtils.getAWSCredentials()
def addFive(r: Record): Int = new String(r.getData.array()).toInt + 5
def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5
val stream = KinesisUtils.createStream(ssc, appName, testUtils.streamName,
testUtils.endpointUrl, testUtils.regionName, InitialPositionInStream.LATEST,
Seconds(10), StorageLevel.MEMORY_ONLY, addFive,
@@ -332,12 +332,13 @@ private void decodeBinaryBatch(int col, int num) throws IOException {
for (int n = 0; n < num; ++n) {
if (columnReaders[col].next()) {
ByteBuffer bytes = columnReaders[col].nextBinary().toByteBuffer();
int len = bytes.limit() - bytes.position();
int len = bytes.remaining();
if (originalTypes[col] == OriginalType.UTF8) {
UTF8String str = UTF8String.fromBytes(bytes.array(), bytes.position(), len);
UTF8String str =
UTF8String.fromBytes(bytes.array(), bytes.arrayOffset() + bytes.position(), len);
rowWriters[n].write(col, str);
} else {
rowWriters[n].write(col, bytes.array(), bytes.position(), len);
rowWriters[n].write(col, bytes.array(), bytes.arrayOffset() + bytes.position(), len);
}
rows[n].setNotNullAt(col);
} else {
@@ -26,6 +26,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.twitter.chill.ResourcePool

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{KryoSerializer, SerializerInstance}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.MutablePair
@@ -76,7 +77,7 @@ private[sql] object SparkSqlSerializer {

def serialize[T: ClassTag](o: T): Array[Byte] =
acquireRelease { k =>
k.serialize(o).array()
JavaUtils.bufferToArray(k.serialize(o))
}

def deserialize[T: ClassTag](bytes: Array[Byte]): T =
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -163,7 +164,9 @@ private[sql] case class InMemoryRelation(
.flatMap(_.values))

batchStats += stats
CachedBatch(rowCount, columnBuilders.map(_.build().array()), stats)
CachedBatch(rowCount, columnBuilders.map { builder =>
JavaUtils.bufferToArray(builder.build())
}, stats)
}

def hasNext: Boolean = rowIterator.hasNext
@@ -327,8 +327,8 @@ private[parquet] class CatalystRowConverter(
// are using `Binary.toByteBuffer.array()` to steal the underlying byte array without copying
// it.
val buffer = value.toByteBuffer
val offset = buffer.position()
val numBytes = buffer.limit() - buffer.position()
val offset = buffer.arrayOffset() + buffer.position()
val numBytes = buffer.remaining()
updater.set(UTF8String.fromBytes(buffer.array(), offset, numBytes))
}
}
@@ -644,8 +644,8 @@ private[parquet] object CatalystRowConverter {
// copying it.
val buffer = binary.toByteBuffer
val bytes = buffer.array()
val start = buffer.position()
val end = buffer.limit()
val start = buffer.arrayOffset() + buffer.position()
val end = buffer.arrayOffset() + buffer.limit()

var unscaled = 0L
var i = start
@@ -27,6 +27,7 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.streaming.Time
import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils}
import org.apache.spark.util.{Clock, Utils}
@@ -212,7 +213,7 @@ private[streaming] class ReceivedBlockTracker(
writeAheadLog.readAll().asScala.foreach { byteBuffer =>
logTrace("Recovering record " + byteBuffer)
Utils.deserialize[ReceivedBlockTrackerLogEvent](
byteBuffer.array, Thread.currentThread().getContextClassLoader) match {
JavaUtils.bufferToArray(byteBuffer), Thread.currentThread().getContextClassLoader) match {
Contributor: It might be worth having a version of Utils.deserialize that takes a ByteBuffer, but that is not required for this change.
case BlockAdditionEvent(receivedBlockInfo) =>
insertAddedBlock(receivedBlockInfo)
case BatchAllocationEvent(time, allocatedBlocks) =>
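A sketch of the overload the reviewer suggests. This helper does not exist in the patch; the name and placement are hypothetical, and it assumes the existing Utils.deserialize(Array[Byte], ClassLoader) and JavaUtils.bufferToArray used above, compiled inside the Spark tree since Utils is private[spark].

import java.nio.ByteBuffer

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.util.Utils

// Hypothetical convenience overload: deserialize straight from a ByteBuffer so that
// call sites like the one above do not need an explicit bufferToArray call.
def deserializeFromBuffer[T](buffer: ByteBuffer, loader: ClassLoader): T = {
  Utils.deserialize[T](JavaUtils.bufferToArray(buffer), loader)
}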
@@ -24,6 +24,8 @@ import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FSDataOutputStream

import org.apache.spark.util.Utils

/**
* A writer for writing byte-buffers to a write ahead log file.
*/
@@ -48,17 +50,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf:
val lengthToWrite = data.remaining()
val segment = new FileBasedWriteAheadLogSegment(path, nextOffset, lengthToWrite)
stream.writeInt(lengthToWrite)
if (data.hasArray) {
stream.write(data.array())
} else {
// If the buffer is not backed by an array, we transfer using temp array
// Note that despite the extra array copy, this should be faster than byte-by-byte copy
while (data.hasRemaining) {
val array = new Array[Byte](data.remaining)
data.get(array)
stream.write(array)
}
}
Utils.writeByteBuffer(data, stream: OutputStream)
Contributor: Weird, was the compiler complaining about something here?

Member Author: Yes, because FSDataOutputStream is both an OutputStream and a DataOutput, so the call is ambiguous without the type ascription.
flush()
nextOffset = stream.getPos()
segment
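To make the ambiguity concrete, here is a sketch using java.io.DataOutputStream as a stand-in for FSDataOutputStream, since both are OutputStreams that also implement DataOutput; as with the earlier sketch, it assumes compilation inside the Spark tree because Utils is private[spark].

import java.io.{ByteArrayOutputStream, DataOutputStream, OutputStream}
import java.nio.ByteBuffer

import org.apache.spark.util.Utils

val stream = new DataOutputStream(new ByteArrayOutputStream())
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3))

// Utils.writeByteBuffer(data, stream)             // would not compile: both overloads apply equally
Utils.writeByteBuffer(data, stream: OutputStream)  // the ascription selects the OutputStream overload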