From 2c4796ff0bc3767c0d42753448342352cc84d33e Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 18 Jan 2023 13:00:09 -0800 Subject: [PATCH] core: Free unused MessageProducer in RetriableStream This prevents leaking message buffers. Fixes #9563 --- .../io/grpc/internal/RetriableStream.java | 1 + .../io/grpc/internal/RetriableStreamTest.java | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index 46da45aebd4..084c367db3f 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -1072,6 +1072,7 @@ public void messagesAvailable(final MessageProducer producer) { checkState( savedState.winningSubstream != null, "Headers should be received prior to messages."); if (savedState.winningSubstream != substream) { + GrpcUtil.closeQuietly(producer); return; } listenerSerializeExecutor.execute( diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index 12bf697027c..6f43589528c 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -60,6 +60,8 @@ import io.grpc.internal.StreamListener.MessageProducer; import java.io.InputStream; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Random; import java.util.concurrent.Executor; @@ -997,6 +999,27 @@ public void messageAvailable() { verify(masterListener).messagesAvailable(messageProducer); } + @Test + public void inboundMessagesClosedOnCancel() throws Exception { + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + + retriableStream.start(masterListener); + retriableStream.request(1); + retriableStream.cancel(Status.CANCELLED.withDescription("on purpose")); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + ClientStreamListener listener = sublistenerCaptor1.getValue(); + listener.headersRead(new Metadata()); + InputStream is = mock(InputStream.class); + listener.messagesAvailable(new FakeMessageProducer(is)); + verify(masterListener, never()).messagesAvailable(any(MessageProducer.class)); + verify(is).close(); + } + @Test public void closedWhileDraining() { ClientStream mockStream1 = mock(ClientStream.class); @@ -2723,4 +2746,22 @@ private interface RetriableStreamRecorder { Status prestart(); } + + private static final class FakeMessageProducer implements MessageProducer { + private final Iterator iterator; + + public FakeMessageProducer(InputStream... iss) { + this.iterator = Arrays.asList(iss).iterator(); + } + + @Override + @Nullable + public InputStream next() { + if (iterator.hasNext()) { + return iterator.next(); + } else { + return null; + } + } + } }