Skip to content

Commit

Permalink
Fix poor buffering case for MultipartReader (square#8665)
Browse files Browse the repository at this point in the history
* Demonstrate poor buffering case
* Fix for repeated reads of small byteCount from large part

(cherry picked from commit 3cc87c3)
  • Loading branch information
yschimke committed Feb 17, 2025
1 parent b2f22c2 commit d36d006
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 5 deletions.
2 changes: 1 addition & 1 deletion okhttp-hpacktests/src/test/resources/hpack-test-case
Submodule hpack-test-case updated 510 files
13 changes: 9 additions & 4 deletions okhttp/src/main/kotlin/okhttp3/MultipartReader.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import java.io.Closeable
import java.io.IOException
import java.net.ProtocolException
import okhttp3.internal.http1.HeadersReader
import okhttp3.internal.limit
import okio.Buffer
import okio.BufferedSource
import okio.ByteString.Companion.encodeUtf8
Expand Down Expand Up @@ -175,10 +176,14 @@ class MultipartReader @Throws(IOException::class) constructor(
* one byte left to read.
*/
private fun currentPartBytesRemaining(maxResult: Long): Long {
source.require(crlfDashDashBoundary.size.toLong())

return when (val delimiterIndex = source.buffer.indexOf(crlfDashDashBoundary)) {
-1L -> minOf(maxResult, source.buffer.size - crlfDashDashBoundary.size + 1)
// Avoid indexOf scanning repeatedly over the entire source by using limit
// Since maxResult could be midway through the boundary, read further to be safe.
val limitSource = source.peek().limit(maxResult + crlfDashDashBoundary.size).buffer()
limitSource.require(crlfDashDashBoundary.size.toLong())

val delimiterIndex = limitSource.buffer.indexOf(crlfDashDashBoundary)
return when (delimiterIndex) {
-1L -> minOf(maxResult, limitSource.buffer.size - crlfDashDashBoundary.size + 1)
else -> minOf(maxResult, delimiterIndex)
}
}
Expand Down
80 changes: 80 additions & 0 deletions okhttp/src/main/kotlin/okhttp3/internal/FixedLengthSource.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (C) 2024 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package okhttp3.internal

import kotlin.jvm.JvmOverloads
import okio.Buffer
import okio.ForwardingSource
import okio.Source

/**
* Return a new [Source] whose [read function][Source.read] returns -1 after [byteCount]
* bytes have been read.
*
* @param onReadExhausted Callback invoked once when the end of bytes has been reached. It receives
* `true` if the end of bytes was because the underlying stream did not contain enough bytes and
* `false` if [byteCount] bytes were successfully read.
*/
@JvmOverloads
internal fun Source.limit(
byteCount: Long,
onReadExhausted: (eof: Boolean) -> Unit = {},
): Source {
require(byteCount >= 0) { "byteCount < 0: $byteCount" }
return FixedLengthSource(this, byteCount, onReadExhausted, truncate = true)
}

internal class FixedLengthSource(
delegate: Source,
private var bytesRemaining: Long,
onReadExhausted: (eof: Boolean) -> Unit,
private val truncate: Boolean,
) : ForwardingSource(delegate) {
/** `null` once invoked. */
private var onReadExhausted: ((eof: Boolean) -> Unit)? = onReadExhausted

override fun read(
sink: Buffer,
byteCount: Long,
): Long {
val requestBytes =
if (truncate) {
if (bytesRemaining == 0L) {
// If the limit was 0 we want to wait until the first call to this function before
// triggering the callback.
onReadExhausted?.invoke(false)
onReadExhausted = null
return -1L
}
minOf(bytesRemaining, byteCount)
} else {
byteCount
}
val readBytes = super.read(sink, requestBytes)
if (readBytes == -1L) {
onReadExhausted!!(true)
onReadExhausted = null
return -1L
}
bytesRemaining -= readBytes
if (bytesRemaining == 0L) {
onReadExhausted!!(false)
onReadExhausted = null
}
return readBytes
}
}
50 changes: 50 additions & 0 deletions okhttp/src/test/java/okhttp3/MultipartReaderTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import java.io.EOFException
import java.net.ProtocolException
import okhttp3.Headers.Companion.headersOf
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.RequestBody.Companion.toRequestBody
import okhttp3.ResponseBody.Companion.toResponseBody
import okio.Buffer
Expand Down Expand Up @@ -538,4 +539,53 @@ class MultipartReaderTest {

assertThat(reader.nextPart()).isNull()
}

@Test
fun `reading a large part with small byteCount`() {
val multipartBody: RequestBody =
MultipartBody.Builder("foo").addPart(
headersOf("header-name", "header-value"),
object : RequestBody() {
override fun contentType(): MediaType? {
return "application/octet-stream".toMediaTypeOrNull()
}

override fun contentLength(): Long {
return (1024 * 1024 * 100).toLong()
}

override fun writeTo(sink: okio.BufferedSink) {
repeat(100) {
sink.writeUtf8(
"a".repeat(1024 * 1024),
)
}
}
},
).build()
val buffer =
Buffer().apply {
multipartBody.writeTo(this)
}

val multipartReader = MultipartReader(buffer, "foo")
while (true) {
val part = multipartReader.nextPart()

if (part == null) break

assertThat(part.headers["header-name"]).isEqualTo("header-value")
while (true) {
val readBuff = Buffer()
val read = part.body.read(readBuff, (1024).toLong())
if (read == -1L) {
break
} else {
assertThat(readBuff.readUtf8()).isEqualTo(
"a".repeat(read.toInt()),
)
}
}
}
}
}

0 comments on commit d36d006

Please sign in to comment.