diff --git a/CHANGES.md b/CHANGES.md index 7c377648c117..973ed762dfd0 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -65,6 +65,8 @@ * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) * Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) * Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation. ([#31682](https://github.com/apache/beam/pull/31682)). +* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376)) +* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index 0e584d564b5c..e1868e2c8461 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -39,6 +39,8 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.SerializableFunctions; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; @@ -99,6 +101,26 @@ * "my_topic")) * * } + * + *

Dynamic Writing to a MQTT Broker

+ * + *

MqttIO also supports dynamic writing to multiple topics based on the data. You can specify a + * function to determine the target topic for each message. The following example demonstrates how + * to configure dynamic topic writing: + * + *

{@code
+ * pipeline
+ *   .apply(...)  // Provide PCollection
+ *   .apply(
+ *     MqttIO.dynamicWrite()
+ *       .withConnectionConfiguration(
+ *         MqttIO.ConnectionConfiguration.create("tcp://host:11883"))
+ *       .withTopicFn()
+ *       .withPayloadFn());
+ * }
+ * + *

This dynamic writing capability allows for more flexible MQTT message routing based on the + * message content, enabling scenarios where messages are directed to different topics. */ @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) @@ -115,8 +137,16 @@ public static Read read() { .build(); } - public static Write write() { - return new AutoValue_MqttIO_Write.Builder().setRetained(false).build(); + public static Write write() { + return new AutoValue_MqttIO_Write.Builder() + .setRetained(false) + .setPayloadFn(SerializableFunctions.identity()) + .setDynamic(false) + .build(); + } + + public static Write dynamicWrite() { + return new AutoValue_MqttIO_Write.Builder().setRetained(false).setDynamic(true).build(); } private MqttIO() {} @@ -127,7 +157,7 @@ public abstract static class ConnectionConfiguration implements Serializable { abstract String getServerUri(); - abstract String getTopic(); + abstract @Nullable String getTopic(); abstract @Nullable String getClientId(); @@ -169,6 +199,11 @@ public static ConnectionConfiguration create(String serverUri, String topic) { .build(); } + public static ConnectionConfiguration create(String serverUri) { + checkArgument(serverUri != null, "serverUri can not be null"); + return new AutoValue_MqttIO_ConnectionConfiguration.Builder().setServerUri(serverUri).build(); + } + /** Set up the MQTT broker URI. */ public ConnectionConfiguration withServerUri(String serverUri) { checkArgument(serverUri != null, "serverUri can not be null"); @@ -199,7 +234,7 @@ public ConnectionConfiguration withPassword(String password) { private void populateDisplayData(DisplayData.Builder builder) { builder.add(DisplayData.item("serverUri", getServerUri())); - builder.add(DisplayData.item("topic", getTopic())); + builder.addIfNotNull(DisplayData.item("topic", getTopic())); builder.addIfNotNull(DisplayData.item("clientId", getClientId())); builder.addIfNotNull(DisplayData.item("username", getUsername())); } @@ -278,6 +313,9 @@ public Read withMaxReadTime(Duration maxReadTime) { @Override public PCollection expand(PBegin input) { + checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null"); + checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null"); + org.apache.beam.sdk.io.Read.Unbounded unbounded = org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this)); @@ -505,29 +543,50 @@ public UnboundedMqttSource getCurrentSource() { /** A {@link PTransform} to write and send a message to a MQTT server. */ @AutoValue - public abstract static class Write extends PTransform, PDone> { - + public abstract static class Write extends PTransform, PDone> { abstract @Nullable ConnectionConfiguration connectionConfiguration(); + abstract @Nullable SerializableFunction topicFn(); + + abstract @Nullable SerializableFunction payloadFn(); + + abstract boolean dynamic(); + abstract boolean retained(); - abstract Builder builder(); + abstract Builder builder(); @AutoValue.Builder - abstract static class Builder { - abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration); + abstract static class Builder { + abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration); + + abstract Builder setRetained(boolean retained); + + abstract Builder setTopicFn(SerializableFunction topicFn); - abstract Builder setRetained(boolean retained); + abstract Builder setPayloadFn(SerializableFunction payloadFn); - abstract Write build(); + abstract Builder setDynamic(boolean dynamic); + + abstract Write build(); } /** Define MQTT connection configuration used to connect to the MQTT broker. */ - public Write withConnectionConfiguration(ConnectionConfiguration configuration) { + public Write withConnectionConfiguration(ConnectionConfiguration configuration) { checkArgument(configuration != null, "configuration can not be null"); return builder().setConnectionConfiguration(configuration).build(); } + public Write withTopicFn(SerializableFunction topicFn) { + checkArgument(dynamic(), "withTopicFn can not use in non-dynamic write"); + return builder().setTopicFn(topicFn).build(); + } + + public Write withPayloadFn(SerializableFunction payloadFn) { + checkArgument(dynamic(), "withPayloadFn can not use in non-dynamic write"); + return builder().setPayloadFn(payloadFn).build(); + } + /** * Whether or not the publish message should be retained by the messaging engine. Sending a * message with the retained set to {@code false} will clear the retained message from the @@ -538,54 +597,76 @@ public Write withConnectionConfiguration(ConnectionConfiguration configuration) * @param retained Whether or not the messaging engine should retain the message. * @return The {@link Write} {@link PTransform} with the corresponding retained configuration. */ - public Write withRetained(boolean retained) { + public Write withRetained(boolean retained) { return builder().setRetained(retained).build(); } - @Override - public PDone expand(PCollection input) { - input.apply(ParDo.of(new WriteFn(this))); - return PDone.in(input.getPipeline()); - } - @Override public void populateDisplayData(DisplayData.Builder builder) { connectionConfiguration().populateDisplayData(builder); builder.add(DisplayData.item("retained", retained())); } - private static class WriteFn extends DoFn { + @Override + public PDone expand(PCollection input) { + checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null"); + if (dynamic()) { + checkArgument( + connectionConfiguration().getTopic() == null, "DynamicWrite can not have static topic"); + checkArgument(topicFn() != null, "topicFn can not be null"); + } else { + checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null"); + } + checkArgument(payloadFn() != null, "payloadFn can not be null"); + + input.apply(ParDo.of(new WriteFn<>(this))); + return PDone.in(input.getPipeline()); + } + + private static class WriteFn extends DoFn { - private final Write spec; + private final Write spec; + private final SerializableFunction topicFn; + private final SerializableFunction payloadFn; + private final boolean retained; private transient MQTT client; private transient BlockingConnection connection; - public WriteFn(Write spec) { + public WriteFn(Write spec) { this.spec = spec; + if (spec.dynamic()) { + this.topicFn = spec.topicFn(); + } else { + String topic = spec.connectionConfiguration().getTopic(); + this.topicFn = ignore -> topic; + } + this.payloadFn = spec.payloadFn(); + this.retained = spec.retained(); } @Setup public void createMqttClient() throws Exception { LOG.debug("Starting MQTT writer"); - client = spec.connectionConfiguration().createClient(); + this.client = this.spec.connectionConfiguration().createClient(); LOG.debug("MQTT writer client ID is {}", client.getClientId()); - connection = createConnection(client); + this.connection = createConnection(client); } @ProcessElement public void processElement(ProcessContext context) throws Exception { - byte[] payload = context.element(); + InputT element = context.element(); + byte[] payload = this.payloadFn.apply(element); + String topic = this.topicFn.apply(element); LOG.debug("Sending message {}", new String(payload, StandardCharsets.UTF_8)); - connection.publish( - spec.connectionConfiguration().getTopic(), payload, QoS.AT_LEAST_ONCE, false); + this.connection.publish(topic, payload, QoS.AT_LEAST_ONCE, this.retained); } @Teardown public void closeMqttClient() throws Exception { - if (connection != null) { + if (this.connection != null) { LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId()); - connection.disconnect(); + this.connection.disconnect(); } } } diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java index 7d60d6d65780..8dfa7838d66a 100644 --- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java +++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.mqtt; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; @@ -26,16 +27,25 @@ import java.io.ObjectOutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.ConcurrentSkipListSet; import org.apache.activemq.broker.BrokerService; import org.apache.activemq.broker.Connection; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.io.common.NetworkTestHelper; import org.apache.beam.sdk.io.mqtt.MqttIO.Read; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.fusesource.hawtbuf.Buffer; import org.fusesource.mqtt.client.BlockingConnection; @@ -266,6 +276,216 @@ public void testWrite() throws Exception { } } + @Test + public void testDynamicWrite() throws Exception { + final int numberOfTopic1Count = 100; + final int numberOfTopic2Count = 100; + final int numberOfTestMessages = numberOfTopic1Count + numberOfTopic2Count; + + MQTT client = new MQTT(); + client.setHost("tcp://localhost:" + port); + final BlockingConnection connection = client.blockingConnection(); + connection.connect(); + final String writeTopic1 = "WRITE_TOPIC_1"; + final String writeTopic2 = "WRITE_TOPIC_2"; + connection.subscribe( + new Topic[] { + new Topic(Buffer.utf8(writeTopic1), QoS.EXACTLY_ONCE), + new Topic(Buffer.utf8(writeTopic2), QoS.EXACTLY_ONCE) + }); + + final Map> messageMap = new ConcurrentSkipListMap<>(); + final Thread subscriber = + new Thread( + () -> { + try { + for (int i = 0; i < numberOfTestMessages; i++) { + Message message = connection.receive(); + List messages = messageMap.get(message.getTopic()); + if (messages == null) { + messages = new ArrayList<>(); + } + messages.add(new String(message.getPayload(), StandardCharsets.UTF_8)); + messageMap.put(message.getTopic(), messages); + message.ack(); + } + } catch (Exception e) { + LOG.error("Can't receive message", e); + } + }); + + subscriber.start(); + + ArrayList> data = new ArrayList<>(); + for (int i = 0; i < numberOfTopic1Count; i++) { + data.add(KV.of(writeTopic1, ("Test" + i).getBytes(StandardCharsets.UTF_8))); + } + + for (int i = 0; i < numberOfTopic2Count; i++) { + data.add(KV.of(writeTopic2, ("Test" + i).getBytes(StandardCharsets.UTF_8))); + } + + pipeline + .apply(Create.of(data)) + .setCoder(KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of())) + .apply( + MqttIO.>dynamicWrite() + .withConnectionConfiguration( + MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port) + .withClientId("READ_PIPELINE")) + .withTopicFn(input -> input.getKey()) + .withPayloadFn(input -> input.getValue())); + + pipeline.run(); + subscriber.join(); + + connection.disconnect(); + + assertEquals( + numberOfTestMessages, messageMap.values().stream().mapToLong(Collection::size).sum()); + + assertEquals(2, messageMap.keySet().size()); + assertTrue(messageMap.containsKey(writeTopic1)); + assertTrue(messageMap.containsKey(writeTopic2)); + for (Map.Entry> entry : messageMap.entrySet()) { + final List messages = entry.getValue(); + messages.forEach(message -> assertTrue(message.contains("Test"))); + if (entry.getKey().equals(writeTopic1)) { + assertEquals(numberOfTopic1Count, messages.size()); + } else { + assertEquals(numberOfTopic2Count, messages.size()); + } + } + } + + @Test + public void testReadHaveNoConnectionConfiguration() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> MqttIO.read().expand(PBegin.in(pipeline))); + + assertEquals("connectionConfiguration can not be null", exception.getMessage()); + } + + @Test + public void testReadHaveNoTopic() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + MqttIO.read() + .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri")) + .expand(PBegin.in(pipeline))); + + assertEquals("topic can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testWriteHaveNoConnectionConfiguration() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> MqttIO.write().expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("connectionConfiguration can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testWriteHaveNoTopic() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + MqttIO.write() + .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri")) + .expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("topic can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testDynamicWriteHaveNoConnectionConfiguration() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> MqttIO.dynamicWrite().expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("connectionConfiguration can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testDynamicWriteHaveNoTopicFn() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + MqttIO.dynamicWrite() + .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri")) + .expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("topicFn can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testDynamicWriteHaveNoPayloadFn() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + MqttIO.dynamicWrite() + .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri")) + .withTopicFn(input -> "topic") + .expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("payloadFn can not be null", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testDynamicWriteHaveStaticTopic() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + MqttIO.dynamicWrite() + .withConnectionConfiguration( + MqttIO.ConnectionConfiguration.create("serverUri", "topic")) + .expand(pipeline.apply(Create.of(new byte[] {})))); + + assertEquals("DynamicWrite can not have static topic", exception.getMessage()); + + pipeline.run(); + } + + @Test + public void testWriteWithTopicFn() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> MqttIO.write().withTopicFn(e -> "some topic")); + + assertEquals("withTopicFn can not use in non-dynamic write", exception.getMessage()); + } + + @Test + public void testWriteWithPayloadFn() { + final IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, () -> MqttIO.write().withPayloadFn(e -> new byte[] {})); + + assertEquals("withPayloadFn can not use in non-dynamic write", exception.getMessage()); + } + @Test public void testReadObject() throws Exception { ByteArrayOutputStream bos = new ByteArrayOutputStream();