diff --git a/servicetalk-buffer-netty/src/main/java/io/servicetalk/buffer/netty/NettyBuffer.java b/servicetalk-buffer-netty/src/main/java/io/servicetalk/buffer/netty/NettyBuffer.java index 4a60f19ed3..b494666be8 100644 --- a/servicetalk-buffer-netty/src/main/java/io/servicetalk/buffer/netty/NettyBuffer.java +++ b/servicetalk-buffer-netty/src/main/java/io/servicetalk/buffer/netty/NettyBuffer.java @@ -377,14 +377,17 @@ public Buffer setBytes(int index, ByteBuffer src) { @Override public int setBytes(int index, InputStream src, int length) throws IOException { + if (length == 0) { + return 0; + } int totalWritten = 0; - int bytesWritten; + int bytesWritten = 0; while (length > 0 && (bytesWritten = buffer.setBytes(index, src, length)) >= 0) { totalWritten += bytesWritten; length -= bytesWritten; index += bytesWritten; } - return totalWritten; + return bytesWritten < 0 && totalWritten == 0 ? -1 : totalWritten; } @Override @@ -687,13 +690,16 @@ public Buffer writeBytes(ByteBuffer src) { @Override public int writeBytes(InputStream src, int length) throws IOException { + if (length == 0) { + return 0; + } int totalWritten = 0; - int bytesWritten; + int bytesWritten = 0; while (length > 0 && (bytesWritten = buffer.writeBytes(src, length)) >= 0) { totalWritten += bytesWritten; length -= bytesWritten; } - return totalWritten; + return bytesWritten < 0 && totalWritten == 0 ? -1 : totalWritten; } @Override diff --git a/servicetalk-buffer-netty/src/test/java/io/servicetalk/buffer/netty/NettyBufferTest.java b/servicetalk-buffer-netty/src/test/java/io/servicetalk/buffer/netty/NettyBufferTest.java index 09034b633a..410bb8e665 100644 --- a/servicetalk-buffer-netty/src/test/java/io/servicetalk/buffer/netty/NettyBufferTest.java +++ b/servicetalk-buffer-netty/src/test/java/io/servicetalk/buffer/netty/NettyBufferTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.ValueSource; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -29,19 +30,38 @@ import static io.servicetalk.buffer.netty.BufferAllocators.PREFER_DIRECT_ALLOCATOR; import static io.servicetalk.buffer.netty.BufferAllocators.PREFER_HEAP_ALLOCATOR; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; class NettyBufferTest { + @ParameterizedTest(name = "{displayName} [{index}] write={0}") + @ValueSource(booleans = {true, false}) + void writeBytesInputStreamZeroLength(boolean write) throws IOException { + Buffer buffer = buffer(true); + byte[] bytes = new byte[100]; + InputStream is = inputStream(bytes, false); + Buffer dup = buffer.duplicate(); + int readBytes; + if (write) { + readBytes = buffer.writeBytes(is, 0); + } else { + readBytes = buffer.setBytes(buffer.writerIndex(), is, 0); + } + assertThat("Read unexpected number of bytes", readBytes, is(0)); + assertThat("Unexpected changes for the buffer", buffer, equalTo(dup)); + } + @ParameterizedTest(name = "{displayName} [{index}] heapBuffer={0} limitRead={1} write={2}") @CsvSource(value = {"false,false,false", "false,false,true", "false,true,false", "false,true,true", "true,false,false", "true,false,true", "true,true,false", "true,true,true"}) - void writeBytesInputStream(boolean heapBuffer, boolean limitRead, boolean write) throws IOException { + void writeBytesInputStreamExactLength(boolean heapBuffer, boolean limitRead, boolean write) throws IOException { Buffer buffer = buffer(heapBuffer); byte[] bytes = new byte[100]; InputStream is = inputStream(bytes, limitRead); writeOrSetBytes(buffer, is, bytes.length, write); assertBytes(buffer, bytes, is, bytes.length); + assertEOF(buffer, is, write); } @ParameterizedTest(name = "{displayName} [{index}] heapBuffer={0} limitRead={1} write={2}") @@ -64,6 +84,7 @@ void writeBytesInputStreamDoubleLength(boolean heapBuffer, boolean limitRead, bo InputStream is = inputStream(bytes, limitRead); writeOrSetBytes(buffer, is, bytes.length * 2, write); assertBytes(buffer, bytes, is, bytes.length); + assertEOF(buffer, is, write); } private static void writeOrSetBytes(Buffer buffer, InputStream is, int length, boolean write) throws IOException { @@ -94,6 +115,7 @@ void writeBytesUntilEndStream(boolean heapBuffer, boolean limitRead, boolean wri buffer.writerIndex(buffer.writerIndex() + written); } assertBytes(buffer, bytes, is, bytes.length); + assertEOF(buffer, is, write); } private static Buffer buffer(boolean heapBuffer) { @@ -114,6 +136,15 @@ private static void assertBytes(Buffer buffer, byte[] bytes, InputStream is, int assertThat("Unexpected available bytes", is.available(), is(bytes.length - length)); } + private static void assertEOF(Buffer buffer, InputStream is, boolean write) throws IOException { + assertThat("Unexpected data from InputStream", is.read(), is(-1)); + if (write) { + assertThat("No EOF signal", buffer.writeBytes(is, 1), is(-1)); + } else { + assertThat("No EOF signal", buffer.setBytes(buffer.writerIndex(), is, 1), is(-1)); + } + } + private static final class TestInputStream extends InputStream { private final InputStream delegate;