diff --git a/CHANGES.md b/CHANGES.md
index 7c377648c117..973ed762dfd0 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -65,6 +65,8 @@
* Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349))
* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))
* Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. ([#31682](https://github.com/apache/beam/pull/31682)).
+* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376))
+* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
## Breaking Changes
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index 0e584d564b5c..e1868e2c8461 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -39,6 +39,8 @@
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
@@ -99,6 +101,26 @@
* "my_topic"))
*
* }
+ *
+ * <h3>Dynamic Writing to a MQTT Broker</h3>
+ *
+ * MqttIO also supports dynamic writing to multiple topics based on the data. You can specify a
+ * function to determine the target topic for each message. The following example demonstrates how
+ * to configure dynamic topic writing:
+ *
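+ * <pre>{@code
+ * pipeline
+ * .apply(...) // Provide a PCollection of the elements to publish
+ * .apply(
+ * MqttIO.dynamicWrite()
+ * .withConnectionConfiguration(
+ * MqttIO.ConnectionConfiguration.create("tcp://host:11883"))
+ * .withTopicFn(element -> ...) // returns the target topic (String) for the element
+ * .withPayloadFn(element -> ...)); // returns the payload (byte[]) for the element
+ * }</pre>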
+ *
+ * This dynamic writing capability allows for more flexible MQTT message routing based on the
+ * message content, enabling scenarios where messages are directed to different topics.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -115,8 +137,16 @@ public static Read read() {
.build();
}
- public static Write write() {
- return new AutoValue_MqttIO_Write.Builder().setRetained(false).build();
+ /**
+ * Creates a non-dynamic {@link Write}: retained is {@code false}, the payload function is the
+ * identity function and dynamic mode is off.
+ *
+ * <p>In this mode {@code expand} requires the connection configuration to carry a topic, and
+ * {@code withTopicFn} / {@code withPayloadFn} are rejected.
+ */
+ public static Write write() {
+ return new AutoValue_MqttIO_Write.Builder()
+ .setRetained(false)
+ .setPayloadFn(SerializableFunctions.identity())
+ .setDynamic(false)
+ .build();
+ }
+
+ /**
+ * Creates a {@link Write} in dynamic mode with retained set to {@code false}.
+ *
+ * <p>No topic or payload function is preset here. {@code expand} rejects the transform unless
+ * both have been supplied with {@code withTopicFn} and {@code withPayloadFn}, and unless the
+ * connection configuration has no static topic.
+ */
+ public static Write dynamicWrite() {
+ return new AutoValue_MqttIO_Write.Builder().setRetained(false).setDynamic(true).build();
}
private MqttIO() {}
@@ -127,7 +157,7 @@ public abstract static class ConnectionConfiguration implements Serializable {
abstract String getServerUri();
- abstract String getTopic();
+ abstract @Nullable String getTopic();
abstract @Nullable String getClientId();
@@ -169,6 +199,11 @@ public static ConnectionConfiguration create(String serverUri, String topic) {
.build();
}
+ /**
+ * Describe a connection configuration to the MQTT broker from the server URI only; the topic is
+ * left unset.
+ *
+ * @param serverUri The MQTT broker URI.
+ * @return A connection configuration to the MQTT broker with no topic.
+ * @throws IllegalArgumentException if {@code serverUri} is null.
+ */
+ public static ConnectionConfiguration create(String serverUri) {
+ checkArgument(serverUri != null, "serverUri can not be null");
+ return new AutoValue_MqttIO_ConnectionConfiguration.Builder().setServerUri(serverUri).build();
+ }
+
/** Set up the MQTT broker URI. */
public ConnectionConfiguration withServerUri(String serverUri) {
checkArgument(serverUri != null, "serverUri can not be null");
@@ -199,7 +234,7 @@ public ConnectionConfiguration withPassword(String password) {
+ /** Adds the server URI and, when they are set, the topic, client id and username to the display data. */
private void populateDisplayData(DisplayData.Builder builder) {
builder.add(DisplayData.item("serverUri", getServerUri()));
- builder.add(DisplayData.item("topic", getTopic()));
+ // The topic is now nullable (see getTopic()), so it is only added when present.
+ builder.addIfNotNull(DisplayData.item("topic", getTopic()));
builder.addIfNotNull(DisplayData.item("clientId", getClientId()));
builder.addIfNotNull(DisplayData.item("username", getUsername()));
}
@@ -278,6 +313,9 @@ public Read withMaxReadTime(Duration maxReadTime) {
@Override
public PCollection expand(PBegin input) {
+ checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
+ checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");
+
org.apache.beam.sdk.io.Read.Unbounded unbounded =
org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this));
@@ -505,29 +543,50 @@ public UnboundedMqttSource getCurrentSource() {
/** A {@link PTransform} to write and send a message to a MQTT server. */
@AutoValue
- public abstract static class Write extends PTransform, PDone> {
-
+ public abstract static class Write extends PTransform, PDone> {
abstract @Nullable ConnectionConfiguration connectionConfiguration();
+ abstract @Nullable SerializableFunction topicFn();
+
+ abstract @Nullable SerializableFunction payloadFn();
+
+ abstract boolean dynamic();
+
abstract boolean retained();
- abstract Builder builder();
+ abstract Builder builder();
+ /** AutoValue builder for {@link Write}, with one setter per abstract property of the transform. */
@AutoValue.Builder
- abstract static class Builder {
- abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration);
+ abstract static class Builder {
+ abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration);
+
+ abstract Builder setRetained(boolean retained);
+
+ abstract Builder setTopicFn(SerializableFunction topicFn);
- abstract Builder setRetained(boolean retained);
+ abstract Builder setPayloadFn(SerializableFunction payloadFn);
- abstract Write build();
+ abstract Builder setDynamic(boolean dynamic);
+
+ abstract Write build();
}
/** Define MQTT connection configuration used to connect to the MQTT broker. */
- public Write withConnectionConfiguration(ConnectionConfiguration configuration) {
+ public Write withConnectionConfiguration(ConnectionConfiguration configuration) {
checkArgument(configuration != null, "configuration can not be null");
return builder().setConnectionConfiguration(configuration).build();
}
+ /**
+ * Defines the function that computes the destination topic of each input element.
+ *
+ * <p>Only valid on a transform created with {@code MqttIO.dynamicWrite()}.
+ *
+ * @param topicFn Function mapping an input element to the topic it is published to.
+ * @return The {@link Write} {@link PTransform} with the corresponding topic function.
+ * @throws IllegalArgumentException if this is a non-dynamic write, or if {@code topicFn} is
+ *     null.
+ */
+ public Write withTopicFn(SerializableFunction topicFn) {
+ checkArgument(dynamic(), "withTopicFn can not use in non-dynamic write");
+ // Fail fast at the call site, as withConnectionConfiguration does, instead of at expand().
+ checkArgument(topicFn != null, "topicFn can not be null");
+ return builder().setTopicFn(topicFn).build();
+ }
+
+ /**
+ * Defines the function that converts each input element into the bytes to publish.
+ *
+ * <p>Only valid on a transform created with {@code MqttIO.dynamicWrite()}.
+ *
+ * @param payloadFn Function mapping an input element to the message payload.
+ * @return The {@link Write} {@link PTransform} with the corresponding payload function.
+ * @throws IllegalArgumentException if this is a non-dynamic write, or if {@code payloadFn} is
+ *     null.
+ */
+ public Write withPayloadFn(SerializableFunction payloadFn) {
+ checkArgument(dynamic(), "withPayloadFn can not use in non-dynamic write");
+ // Fail fast at the call site, as withConnectionConfiguration does, instead of at expand().
+ checkArgument(payloadFn != null, "payloadFn can not be null");
+ return builder().setPayloadFn(payloadFn).build();
+ }
+
/**
* Whether or not the publish message should be retained by the messaging engine. Sending a
* message with the retained set to {@code false} will clear the retained message from the
@@ -538,54 +597,76 @@ public Write withConnectionConfiguration(ConnectionConfiguration configuration)
* @param retained Whether or not the messaging engine should retain the message.
* @return The {@link Write} {@link PTransform} with the corresponding retained configuration.
*/
- public Write withRetained(boolean retained) {
+ public Write withRetained(boolean retained) {
return builder().setRetained(retained).build();
}
- @Override
- public PDone expand(PCollection input) {
- input.apply(ParDo.of(new WriteFn(this)));
- return PDone.in(input.getPipeline());
- }
-
+ /** Adds the connection configuration's display data and the {@code retained} flag. */
@Override
public void populateDisplayData(DisplayData.Builder builder) {
connectionConfiguration().populateDisplayData(builder);
builder.add(DisplayData.item("retained", retained()));
}
- private static class WriteFn extends DoFn {
+ /**
+ * Validates the configuration and applies a {@link ParDo} of {@code WriteFn} to the input.
+ *
+ * <p>A connection configuration and a payload function are always required. A dynamic write
+ * additionally requires a topic function and a connection configuration without a topic; a
+ * non-dynamic write requires a topic on the connection configuration.
+ *
+ * @throws IllegalArgumentException if any of these conditions does not hold.
+ */
+ @Override
+ public PDone expand(PCollection input) {
+ checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
+ if (dynamic()) {
+ // A static topic would conflict with the per-element topic function.
+ checkArgument(
+ connectionConfiguration().getTopic() == null, "DynamicWrite can not have static topic");
+ checkArgument(topicFn() != null, "topicFn can not be null");
+ } else {
+ checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");
+ }
+ checkArgument(payloadFn() != null, "payloadFn can not be null");
+
+ input.apply(ParDo.of(new WriteFn<>(this)));
+ return PDone.in(input.getPipeline());
+ }
+
+ private static class WriteFn extends DoFn {
- private final Write spec;
+ private final Write spec;
+ private final SerializableFunction topicFn;
+ private final SerializableFunction payloadFn;
+ private final boolean retained;
private transient MQTT client;
private transient BlockingConnection connection;
- public WriteFn(Write spec) {
+ /**
+ * Captures the topic function, payload function and retained flag from the given spec.
+ *
+ * <p>For a non-dynamic write, the topic function is a constant function returning the topic of
+ * the connection configuration.
+ */
+ public WriteFn(Write spec) {
this.spec = spec;
+ if (spec.dynamic()) {
+ this.topicFn = spec.topicFn();
+ } else {
+ // Static topic: read it once here and ignore the element at processing time.
+ String topic = spec.connectionConfiguration().getTopic();
+ this.topicFn = ignore -> topic;
+ }
+ this.payloadFn = spec.payloadFn();
+ this.retained = spec.retained();
}
+ /** Creates the MQTT client from the connection configuration and opens the connection. */
@Setup
public void createMqttClient() throws Exception {
LOG.debug("Starting MQTT writer");
- client = spec.connectionConfiguration().createClient();
+ this.client = this.spec.connectionConfiguration().createClient();
LOG.debug("MQTT writer client ID is {}", client.getClientId());
- connection = createConnection(client);
+ this.connection = createConnection(client);
}
+ /**
+ * Publishes one element: the payload and topic are computed from the element with the payload
+ * and topic functions, then sent with {@code QoS.AT_LEAST_ONCE} and the configured retained
+ * flag.
+ */
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
- byte[] payload = context.element();
+ InputT element = context.element();
+ byte[] payload = this.payloadFn.apply(element);
+ String topic = this.topicFn.apply(element);
LOG.debug("Sending message {}", new String(payload, StandardCharsets.UTF_8));
- connection.publish(
- spec.connectionConfiguration().getTopic(), payload, QoS.AT_LEAST_ONCE, false);
+ // The retained flag now comes from the spec; the removed code always passed false.
+ this.connection.publish(topic, payload, QoS.AT_LEAST_ONCE, this.retained);
}
+ /** Disconnects the MQTT connection if one was opened. */
@Teardown
public void closeMqttClient() throws Exception {
- if (connection != null) {
+ if (this.connection != null) {
LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId());
- connection.disconnect();
+ this.connection.disconnect();
}
}
}
diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
index 7d60d6d65780..8dfa7838d66a 100644
--- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
+++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
@@ -18,6 +18,7 @@
package org.apache.beam.sdk.io.mqtt;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayInputStream;
@@ -26,16 +27,25 @@
import java.io.ObjectOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.UUID;
+import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.ConcurrentSkipListSet;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.Connection;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.common.NetworkTestHelper;
import org.apache.beam.sdk.io.mqtt.MqttIO.Read;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.fusesource.hawtbuf.Buffer;
import org.fusesource.mqtt.client.BlockingConnection;
@@ -266,6 +276,216 @@ public void testWrite() throws Exception {
}
}
+ /**
+ * Writes 100 messages to each of two topics through {@code MqttIO.dynamicWrite()}, using the KV
+ * key as topic and the KV value as payload, and checks that a subscriber receives every message
+ * on the expected topic.
+ */
+ @Test
+ public void testDynamicWrite() throws Exception {
+ final int numberOfTopic1Count = 100;
+ final int numberOfTopic2Count = 100;
+ final int numberOfTestMessages = numberOfTopic1Count + numberOfTopic2Count;
+
+ // Subscribe to both target topics before the pipeline starts publishing.
+ MQTT client = new MQTT();
+ client.setHost("tcp://localhost:" + port);
+ final BlockingConnection connection = client.blockingConnection();
+ connection.connect();
+ final String writeTopic1 = "WRITE_TOPIC_1";
+ final String writeTopic2 = "WRITE_TOPIC_2";
+ connection.subscribe(
+ new Topic[] {
+ new Topic(Buffer.utf8(writeTopic1), QoS.EXACTLY_ONCE),
+ new Topic(Buffer.utf8(writeTopic2), QoS.EXACTLY_ONCE)
+ });
+
+ // Received payloads grouped by the topic they arrived on.
+ final Map> messageMap = new ConcurrentSkipListMap<>();
+ final Thread subscriber =
+ new Thread(
+ () -> {
+ try {
+ // Receives exactly numberOfTestMessages messages, then the thread ends.
+ for (int i = 0; i < numberOfTestMessages; i++) {
+ Message message = connection.receive();
+ List messages = messageMap.get(message.getTopic());
+ if (messages == null) {
+ messages = new ArrayList<>();
+ }
+ messages.add(new String(message.getPayload(), StandardCharsets.UTF_8));
+ messageMap.put(message.getTopic(), messages);
+ message.ack();
+ }
+ } catch (Exception e) {
+ LOG.error("Can't receive message", e);
+ }
+ });
+
+ subscriber.start();
+
+ // Input elements: key is the destination topic, value is the payload.
+ ArrayList> data = new ArrayList<>();
+ for (int i = 0; i < numberOfTopic1Count; i++) {
+ data.add(KV.of(writeTopic1, ("Test" + i).getBytes(StandardCharsets.UTF_8)));
+ }
+
+ for (int i = 0; i < numberOfTopic2Count; i++) {
+ data.add(KV.of(writeTopic2, ("Test" + i).getBytes(StandardCharsets.UTF_8)));
+ }
+
+ pipeline
+ .apply(Create.of(data))
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of()))
+ .apply(
+ MqttIO.>dynamicWrite()
+ .withConnectionConfiguration(
+ MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port)
+ .withClientId("READ_PIPELINE"))
+ .withTopicFn(input -> input.getKey())
+ .withPayloadFn(input -> input.getValue()));
+
+ pipeline.run();
+ // Blocks until the subscriber thread has finished its receive loop.
+ subscriber.join();
+
+ connection.disconnect();
+
+ assertEquals(
+ numberOfTestMessages, messageMap.values().stream().mapToLong(Collection::size).sum());
+
+ assertEquals(2, messageMap.keySet().size());
+ assertTrue(messageMap.containsKey(writeTopic1));
+ assertTrue(messageMap.containsKey(writeTopic2));
+ for (Map.Entry> entry : messageMap.entrySet()) {
+ final List messages = entry.getValue();
+ messages.forEach(message -> assertTrue(message.contains("Test")));
+ if (entry.getKey().equals(writeTopic1)) {
+ assertEquals(numberOfTopic1Count, messages.size());
+ } else {
+ assertEquals(numberOfTopic2Count, messages.size());
+ }
+ }
+ }
+
+ /** Expanding a read that has no connection configuration must fail with a clear message. */
+ @Test
+ public void testReadHaveNoConnectionConfiguration() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class, () -> MqttIO.read().expand(PBegin.in(pipeline)));
+
+ assertEquals("connectionConfiguration can not be null", exception.getMessage());
+ }
+
+ /** Expanding a read whose connection configuration has no topic must fail. */
+ @Test
+ public void testReadHaveNoTopic() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ MqttIO.read()
+ .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+ .expand(PBegin.in(pipeline)));
+
+ assertEquals("topic can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a non-dynamic write that has no connection configuration must fail. */
+ @Test
+ public void testWriteHaveNoConnectionConfiguration() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> MqttIO.write().expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("connectionConfiguration can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a non-dynamic write whose connection configuration has no topic must fail. */
+ @Test
+ public void testWriteHaveNoTopic() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ MqttIO.write()
+ .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+ .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("topic can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a dynamic write that has no connection configuration must fail. */
+ @Test
+ public void testDynamicWriteHaveNoConnectionConfiguration() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> MqttIO.dynamicWrite().expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("connectionConfiguration can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a dynamic write for which no topic function was set must fail. */
+ @Test
+ public void testDynamicWriteHaveNoTopicFn() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ MqttIO.dynamicWrite()
+ .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+ .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("topicFn can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a dynamic write that has a topic function but no payload function must fail. */
+ @Test
+ public void testDynamicWriteHaveNoPayloadFn() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ MqttIO.dynamicWrite()
+ .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+ .withTopicFn(input -> "topic")
+ .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("payloadFn can not be null", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Expanding a dynamic write whose connection configuration carries a static topic must fail. */
+ @Test
+ public void testDynamicWriteHaveStaticTopic() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class,
+ () ->
+ MqttIO.dynamicWrite()
+ .withConnectionConfiguration(
+ MqttIO.ConnectionConfiguration.create("serverUri", "topic"))
+ .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+ assertEquals("DynamicWrite can not have static topic", exception.getMessage());
+
+ pipeline.run();
+ }
+
+ /** Setting a topic function on a non-dynamic write must be rejected immediately. */
+ @Test
+ public void testWriteWithTopicFn() {
+ IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class, () -> MqttIO.write().withTopicFn(e -> "some topic"));
+
+ assertEquals("withTopicFn can not use in non-dynamic write", exception.getMessage());
+ }
+
+ /** Setting a payload function on a non-dynamic write must be rejected immediately. */
+ @Test
+ public void testWriteWithPayloadFn() {
+ final IllegalArgumentException exception =
+ assertThrows(
+ IllegalArgumentException.class, () -> MqttIO.write().withPayloadFn(e -> new byte[] {}));
+
+ assertEquals("withPayloadFn can not use in non-dynamic write", exception.getMessage());
+ }
+
@Test
public void testReadObject() throws Exception {
ByteArrayOutputStream bos = new ByteArrayOutputStream();