Skip to content

Commit

Permalink
Guard against use-after-close in DirectByteBufferOutputStream
Browse files Browse the repository at this point in the history
  • Loading branch information
ankurdave committed Nov 8, 2024
1 parent 83b8add commit 0ffe28a
Showing 1 changed file with 19 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS
def this() = this(32)

override def write(b: Int): Unit = {
checkNotClosed()
ensureCapacity(buffer.position() + 1)
buffer.put(b.toByte)
}

override def write(b: Array[Byte], off: Int, len: Int): Unit = {
checkNotClosed()
ensureCapacity(buffer.position() + len)
buffer.put(b, off, len)
}
Expand All @@ -63,15 +65,29 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS
buffer = newBuffer
}

def reset(): Unit = buffer.clear()
private def checkNotClosed(): Unit = {
if (buffer == null) {
throw new IllegalStateException(
"Cannot call methods on a closed DirectByteBufferOutputStream")
}
}

def reset(): Unit = {
checkNotClosed()
buffer.clear()
}

def size(): Int = buffer.position()
def size(): Int = {
checkNotClosed()
buffer.position()
}

/**
* Any subsequent call to [[close()]], [[write()]], [[reset()]] will invalidate the buffer
* returned by this method.
*/
def toByteBuffer: ByteBuffer = {
checkNotClosed()
val outputBuffer = buffer.duplicate()
outputBuffer.flip()
outputBuffer
Expand All @@ -80,6 +96,7 @@ private[spark] class DirectByteBufferOutputStream(capacity: Int) extends OutputS
override def close(): Unit = {
// Eagerly free the direct byte buffer without waiting for GC to reduce memory pressure.
StorageUtils.dispose(buffer)
buffer = null
}

}

0 comments on commit 0ffe28a

Please sign in to comment.