diff --git a/flink-connectors/flink-connector-gcp-pubsub/src/main/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSource.java b/flink-connectors/flink-connector-gcp-pubsub/src/main/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSource.java index 4ddd8160cef90..1c7baaf1926ed 100644 --- a/flink-connectors/flink-connector-gcp-pubsub/src/main/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSource.java +++ b/flink-connectors/flink-connector-gcp-pubsub/src/main/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSource.java @@ -18,6 +18,8 @@ package org.apache.flink.streaming.connectors.gcp.pubsub; import org.apache.flink.api.common.functions.RuntimeContext; +import org.apache.flink.api.common.io.ratelimiting.FlinkConnectorRateLimiter; +import org.apache.flink.api.common.io.ratelimiting.GuavaFlinkConnectorRateLimiter; import org.apache.flink.api.common.serialization.DeserializationSchema; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.typeutils.ResultTypeQueryable; @@ -40,8 +42,6 @@ import com.google.pubsub.v1.ProjectSubscriptionName; import com.google.pubsub.v1.PubsubMessage; import com.google.pubsub.v1.ReceivedMessage; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.Serializable; @@ -57,13 +57,12 @@ */ public class PubSubSource extends RichSourceFunction implements ResultTypeQueryable, ParallelSourceFunction, CheckpointListener, ListCheckpointed> { - public static final int NO_MAX_MESSAGES_TO_ACKNOWLEDGE_LIMIT = -1; - private static final Logger LOG = LoggerFactory.getLogger(PubSubSource.class); protected final PubSubDeserializationSchema deserializationSchema; protected final PubSubSubscriberFactory pubSubSubscriberFactory; protected final Credentials credentials; - protected final int maxMessagesToAcknowledge; protected final AcknowledgeOnCheckpointFactory acknowledgeOnCheckpointFactory; + protected final
FlinkConnectorRateLimiter rateLimiter; + protected final int messagePerSecondRateLimit; protected transient AcknowledgeOnCheckpoint acknowledgeOnCheckpoint; protected transient PubSubSubscriber subscriber; @@ -73,13 +72,15 @@ public class PubSubSource extends RichSourceFunction PubSubSource(PubSubDeserializationSchema deserializationSchema, PubSubSubscriberFactory pubSubSubscriberFactory, Credentials credentials, - int maxMessagesToAcknowledge, - AcknowledgeOnCheckpointFactory acknowledgeOnCheckpointFactory) { + AcknowledgeOnCheckpointFactory acknowledgeOnCheckpointFactory, + FlinkConnectorRateLimiter rateLimiter, + int messagePerSecondRateLimit) { this.deserializationSchema = deserializationSchema; this.pubSubSubscriberFactory = pubSubSubscriberFactory; this.credentials = credentials; - this.maxMessagesToAcknowledge = maxMessagesToAcknowledge; this.acknowledgeOnCheckpointFactory = acknowledgeOnCheckpointFactory; + this.rateLimiter = rateLimiter; + this.messagePerSecondRateLimit = messagePerSecondRateLimit; } @Override @@ -92,6 +93,10 @@ public void open(Configuration configuration) throws Exception { getRuntimeContext().getMetricGroup().gauge("PubSubMessagesProcessedNotAcked", this::getOutstandingMessagesToAck); + //convert per-subtask-limit to global rate limit, as FlinkConnectorRateLimiter::setRate expects a global rate limit. The limit is widened to long first so the multiplication cannot overflow int.
+ rateLimiter.setRate((long) messagePerSecondRateLimit * getRuntimeContext().getNumberOfParallelSubtasks()); + rateLimiter.open(getRuntimeContext()); + createAndSetPubSubSubscriber(); this.isRunning = true; } @@ -104,11 +109,6 @@ private boolean hasNoCheckpointingEnabled(RuntimeContext runtimeContext) { public void run(SourceContext sourceContext) throws Exception { while (isRunning) { try { - if (maxMessagesToAcknowledgeLimitReached()) { - LOG.debug("Sleeping because there are {} messages waiting to be ack'ed but limit is {}", getOutstandingMessagesToAck(), maxMessagesToAcknowledge); - Thread.sleep(100); - continue; - } processMessage(sourceContext, subscriber.pull()); } catch (InterruptedException | CancellationException e) { @@ -119,6 +119,8 @@ public void run(SourceContext sourceContext) throws Exception { } void processMessage(SourceContext sourceContext, List messages) throws Exception { + rateLimiter.acquire(messages.size()); + synchronized (sourceContext.getCheckpointLock()) { for (ReceivedMessage message : messages) { acknowledgeOnCheckpoint.addAcknowledgeId(message.getAckId()); @@ -137,10 +139,6 @@ void processMessage(SourceContext sourceContext, List mess } } - private boolean maxMessagesToAcknowledgeLimitReached() throws Exception { - return maxMessagesToAcknowledge != NO_MAX_MESSAGES_TO_ACKNOWLEDGE_LIMIT && getOutstandingMessagesToAck() > maxMessagesToAcknowledge; - } - private Integer getOutstandingMessagesToAck() { return acknowledgeOnCheckpoint.numberOfOutstandingAcknowledgements(); } @@ -197,6 +195,7 @@ public static class PubSubSourceBuilder implements ProjectNameBuilder, private PubSubSubscriberFactory pubSubSubscriberFactory; private Credentials credentials; private int maxMessageToAcknowledge = 10000; + private int messagePerSecondRateLimit = 100000; private PubSubSourceBuilder(DeserializationSchema deserializationSchema) { Preconditions.checkNotNull(deserializationSchema); @@ -264,12 +263,13 @@ public PubSubSourceBuilder
withPubSubSubscriberFactory(int maxMessagesPerPu } /** - * Set a limit of the number of outstanding or to-be acknowledged messages. - * default is 10000. Adjust this if you have high checkpoint intervals and / or run into memory issues - * due to the amount of acknowledgement ids. Use {@link PubSubSource}.NO_MAX_MESSAGES_TO_ACKNOWLEDGE_LIMIT if you want to remove the limit. + * Set a limit on the rate of messages per second received. This limit is per parallel instance of the source function. + * Default is set to 100000 messages per second. + * + * @param messagePerSecondRateLimit the message per second rate limit. */ - public PubSubSourceBuilder withMaxMessageToAcknowledge(int maxMessageToAcknowledge) { - this.maxMessageToAcknowledge = maxMessageToAcknowledge; + public PubSubSourceBuilder withMessageRateLimit(int messagePerSecondRateLimit) { + this.messagePerSecondRateLimit = messagePerSecondRateLimit; return this; } @@ -292,7 +292,7 @@ public PubSubSource build() throws IOException { 100); } - return new PubSubSource<>(deserializationSchema, pubSubSubscriberFactory, credentials, maxMessageToAcknowledge, new AcknowledgeOnCheckpointFactory()); + return new PubSubSource<>(deserializationSchema, pubSubSubscriberFactory, credentials, new AcknowledgeOnCheckpointFactory(), new GuavaFlinkConnectorRateLimiter(), messagePerSecondRateLimit); } } diff --git a/flink-connectors/flink-connector-gcp-pubsub/src/test/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSourceTest.java b/flink-connectors/flink-connector-gcp-pubsub/src/test/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSourceTest.java index 943f07524218c..bb6f0c366ac0a 100644 --- a/flink-connectors/flink-connector-gcp-pubsub/src/test/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSourceTest.java +++ b/flink-connectors/flink-connector-gcp-pubsub/src/test/java/org/apache/flink/streaming/connectors/gcp/pubsub/PubSubSourceTest.java @@ -17,6 +17,7 @@ package
org.apache.flink.streaming.connectors.gcp.pubsub; +import org.apache.flink.api.common.io.ratelimiting.FlinkConnectorRateLimiter; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.metrics.MetricGroup; import org.apache.flink.streaming.api.functions.source.SourceFunction; @@ -78,6 +79,8 @@ public class PubSubSourceTest { private Credentials credentials; @Mock private PubSubSubscriber pubsubSubscriber; + @Mock + private FlinkConnectorRateLimiter rateLimiter; private PubSubSource pubSubSource; @@ -91,8 +94,9 @@ public void setup() throws Exception { pubSubSource = new PubSubSource<>(deserializationSchema, pubSubSubscriberFactory, credentials, - 100, - acknowledgeOnCheckpointFactory); + acknowledgeOnCheckpointFactory, + rateLimiter, + 1024); pubSubSource.setRuntimeContext(streamingRuntimeContext); } @@ -120,10 +124,15 @@ public void testProcessMessage() throws Exception { when(sourceContext.getCheckpointLock()).thenReturn("some object to lock on"); pubSubSource.open(null); - pubSubSource.processMessage(sourceContext, asList(receivedMessage("firstAckId", pubSubMessage(FIRST_MESSAGE)), - receivedMessage("secondAckId", pubSubMessage(SECOND_MESSAGE)))); + List receivedMessages = asList( + receivedMessage("firstAckId", pubSubMessage(FIRST_MESSAGE)), + receivedMessage("secondAckId", pubSubMessage(SECOND_MESSAGE)) + ); + pubSubSource.processMessage(sourceContext, receivedMessages); //verify handling messages + verify(rateLimiter, times(1)).acquire(2); + verify(sourceContext, times(1)).getCheckpointLock(); verify(deserializationSchema, times(1)).isEndOfStream(FIRST_MESSAGE); verify(deserializationSchema, times(1)).deserialize(pubSubMessage(FIRST_MESSAGE)); diff --git a/flink-examples/flink-examples-build-helper/flink-examples-streaming-gcp-pubsub/src/main/java/org/apache/flink/streaming/examples/gcp/pubsub/PubSubExample.java
b/flink-examples/flink-examples-build-helper/flink-examples-streaming-gcp-pubsub/src/main/java/org/apache/flink/streaming/examples/gcp/pubsub/PubSubExample.java index a9601768b45a7..b79c67ee99438 100644 --- a/flink-examples/flink-examples-build-helper/flink-examples-streaming-gcp-pubsub/src/main/java/org/apache/flink/streaming/examples/gcp/pubsub/PubSubExample.java +++ b/flink-examples/flink-examples-build-helper/flink-examples-streaming-gcp-pubsub/src/main/java/org/apache/flink/streaming/examples/gcp/pubsub/PubSubExample.java @@ -67,6 +67,7 @@ private static void runFlinkJob(String projectName, String subscriptionName, Str .withDeserializationSchema(new IntegerSerializer()) .withProjectName(projectName) .withSubscriptionName(subscriptionName) + .withMessageRateLimit(1) .build()) .map(PubSubExample::printAndReturn).disableChaining() .addSink(PubSubSink.newBuilder()