diff --git a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json index f1ba03a243ee..455144f02a35 100644 --- a/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json +++ b/.github/trigger_files/beam_PostCommit_Java_PVR_Spark_Batch.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 5 + "modification": 6 } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json index 9b023f630c36..03d86a8d023e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json index 9b023f630c36..03d86a8d023e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Spark_Java11.json @@ -2,5 +2,6 @@ "comment": "Modify this file in a trivial way to cause this test suite to run", "https://github.com/apache/beam/pull/31156": "noting that PR #31156 should run this test", "https://github.com/apache/beam/pull/31798": "noting that PR #31798 should run this test", - "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test" + "https://github.com/apache/beam/pull/32546": "noting that PR #32546 should run this test", + "https://github.com/apache/beam/pull/33322": "noting that PR #33322 should run this test" } diff --git a/CHANGES.md b/CHANGES.md index 1b943a99f8a0..869644c52cec 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -67,6 +67,7 @@ ## New Features / Improvements +* Improved batch performance of SparkRunner's GroupByKey ([#20943](https://github.com/apache/beam/pull/20943)). * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). ## Breaking Changes diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java index 62c5e2579427..1d8901ed5ffc 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupCombineFunctions.java @@ -17,6 +17,9 @@ */ package org.apache.beam.runners.spark.translation; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.ByteArray; import org.apache.beam.sdk.coders.Coder; @@ -27,6 +30,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; @@ -49,18 +53,36 @@ public static JavaRDD>>> groupByKeyOnly( @Nullable Partitioner partitioner) { // we use coders to convert objects in the PCollection to byte arrays, so they // can be transferred over the network for the shuffle. - JavaPairRDD pairRDD = - rdd.map(new ReifyTimestampsAndWindowsFunction<>()) - .mapToPair(TranslationUtils.toPairFunction()) - .mapToPair(CoderHelpers.toByteFunction(keyCoder, wvCoder)); - - // If no partitioner is passed, the default group by key operation is called - JavaPairRDD> groupedRDD = - (partitioner != null) ? pairRDD.groupByKey(partitioner) : pairRDD.groupByKey(); - - return groupedRDD - .mapToPair(CoderHelpers.fromByteFunctionIterable(keyCoder, wvCoder)) - .map(new TranslationUtils.FromPairFunction<>()); + final JavaPairRDD pairRDD = + rdd.mapPartitionsToPair( + (Iterator>> iter) -> + Iterators.transform( + iter, + (WindowedValue> wv) -> { + final K key = wv.getValue().getKey(); + final WindowedValue windowedValue = wv.withValue(wv.getValue().getValue()); + final ByteArray keyBytes = + new ByteArray(CoderHelpers.toByteArray(key, keyCoder)); + final byte[] windowedValueBytes = + CoderHelpers.toByteArray(windowedValue, wvCoder); + return Tuple2.apply(keyBytes, windowedValueBytes); + })); + + final JavaPairRDD> combined = + GroupNonMergingWindowsFunctions.combineByKey(pairRDD, partitioner).cache(); + + return combined.mapPartitions( + (Iterator>> iter) -> + Iterators.transform( + iter, + (Tuple2> tuple) -> { + final K key = CoderHelpers.fromByteArray(tuple._1().getValue(), keyCoder); + final List> windowedValues = + tuple._2().stream() + .map(bytes -> CoderHelpers.fromByteArray(bytes, wvCoder)) + .collect(Collectors.toList()); + return KV.of(key, windowedValues); + })); } /** diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java index 2461d5cc8d66..14630fbb0a1f 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctions.java @@ -17,7 +17,9 @@ */ package org.apache.beam.runners.spark.translation; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; import java.util.Objects; import org.apache.beam.runners.spark.coders.CoderHelpers; import org.apache.beam.runners.spark.util.ByteArray; @@ -41,6 +43,9 @@ import org.apache.spark.Partitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -259,9 +264,12 @@ private WindowedValue> decodeItem(Tuple2 item) { } /** - * Group all values with a given key for that composite key with Spark's groupByKey, dropping the - * Window (which must be GlobalWindow) and returning the grouped result in the appropriate global - * window. + * Groups values with a given key using Spark's combineByKey operation in the Global Window + * context. The window information (which must be GlobalWindow) is dropped during processing, and + * the grouped results are returned in the appropriate global window with the maximum timestamp. + * + *

This implementation uses {@link JavaPairRDD#combineByKey} for better performance compared to + * {@link JavaPairRDD#groupByKey}, as it allows for local aggregation before shuffle operations. */ static JavaRDD>>> groupByKeyInGlobalWindow( @@ -269,24 +277,70 @@ JavaRDD>>> groupByKeyInGlobalWindow( Coder keyCoder, Coder valueCoder, Partitioner partitioner) { - JavaPairRDD rawKeyValues = - rdd.mapToPair( - wv -> - new Tuple2<>( - new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)), - CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder))); - - JavaPairRDD> grouped = - (partitioner == null) ? rawKeyValues.groupByKey() : rawKeyValues.groupByKey(partitioner); - return grouped.map( - kvs -> - WindowedValue.timestampedValueInGlobalWindow( - KV.of( - CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), - Iterables.transform( - kvs._2, - encodedValue -> CoderHelpers.fromByteArray(encodedValue, valueCoder))), - GlobalWindow.INSTANCE.maxTimestamp(), - PaneInfo.ON_TIME_AND_ONLY_FIRING)); + final JavaPairRDD rawKeyValues = + rdd.mapPartitionsToPair( + (Iterator>> iter) -> + Iterators.transform( + iter, + (WindowedValue> wv) -> { + final ByteArray keyBytes = + new ByteArray(CoderHelpers.toByteArray(wv.getValue().getKey(), keyCoder)); + final byte[] valueBytes = + CoderHelpers.toByteArray(wv.getValue().getValue(), valueCoder); + return Tuple2.apply(keyBytes, valueBytes); + })); + + JavaPairRDD> combined = combineByKey(rawKeyValues, partitioner).cache(); + + return combined.mapPartitions( + (Iterator>> iter) -> + Iterators.transform( + iter, + kvs -> + WindowedValue.timestampedValueInGlobalWindow( + KV.of( + CoderHelpers.fromByteArray(kvs._1.getValue(), keyCoder), + Iterables.transform( + kvs._2(), + encodedValue -> + CoderHelpers.fromByteArray(encodedValue, valueCoder))), + GlobalWindow.INSTANCE.maxTimestamp(), + PaneInfo.ON_TIME_AND_ONLY_FIRING))); + } + + /** + * Combines values by key using Spark's {@link JavaPairRDD#combineByKey} operation. + * + * @param rawKeyValues Input RDD of key-value pairs + * @param partitioner Optional custom partitioner for data distribution + * @return RDD with values combined into Lists per key + */ + static JavaPairRDD> combineByKey( + JavaPairRDD rawKeyValues, @Nullable Partitioner partitioner) { + + final Function> createCombiner = + value -> { + List list = new ArrayList<>(); + list.add(value); + return list; + }; + + final Function2, byte[], List> mergeValues = + (list, value) -> { + list.add(value); + return list; + }; + + final Function2, List, List> mergeCombiners = + (list1, list2) -> { + list1.addAll(list2); + return list1; + }; + + if (partitioner == null) { + return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners); + } + + return rawKeyValues.combineByKey(createCombiner, mergeValues, mergeCombiners, partitioner); } } diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java index ed7bc078564e..fd299924af91 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/translation/GroupNonMergingWindowsFunctionsTest.java @@ -18,12 +18,6 @@ package org.apache.beam.runners.spark.translation; import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import java.util.Arrays; import java.util.Iterator; @@ -45,9 +39,6 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.primitives.Bytes; -import org.apache.spark.Partitioner; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; import org.joda.time.Duration; import org.joda.time.Instant; import org.junit.Assert; @@ -121,54 +112,6 @@ public void testGbkIteratorValuesCannotBeReiterated() throws Coder.NonDeterminis } } - @Test - @SuppressWarnings({"rawtypes", "unchecked"}) - public void testGroupByKeyInGlobalWindowWithPartitioner() { - // mocking - Partitioner mockPartitioner = mock(Partitioner.class); - JavaRDD mockRdd = mock(JavaRDD.class); - Coder mockKeyCoder = mock(Coder.class); - Coder mockValueCoder = mock(Coder.class); - JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); - JavaPairRDD mockGrouped = mock(JavaPairRDD.class); - - when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); - when(mockRawKeyValues.groupByKey(any(Partitioner.class))) - .thenAnswer( - invocation -> { - Partitioner partitioner = invocation.getArgument(0); - assertEquals(partitioner, mockPartitioner); - return mockGrouped; - }); - when(mockGrouped.map(any())).thenReturn(mock(JavaRDD.class)); - - GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( - mockRdd, mockKeyCoder, mockValueCoder, mockPartitioner); - - verify(mockRawKeyValues, never()).groupByKey(); - verify(mockRawKeyValues, times(1)).groupByKey(any(Partitioner.class)); - } - - @Test - @SuppressWarnings({"rawtypes", "unchecked"}) - public void testGroupByKeyInGlobalWindowWithoutPartitioner() { - // mocking - JavaRDD mockRdd = mock(JavaRDD.class); - Coder mockKeyCoder = mock(Coder.class); - Coder mockValueCoder = mock(Coder.class); - JavaPairRDD mockRawKeyValues = mock(JavaPairRDD.class); - JavaPairRDD mockGrouped = mock(JavaPairRDD.class); - - when(mockRdd.mapToPair(any())).thenReturn(mockRawKeyValues); - when(mockRawKeyValues.groupByKey()).thenReturn(mockGrouped); - - GroupNonMergingWindowsFunctions.groupByKeyInGlobalWindow( - mockRdd, mockKeyCoder, mockValueCoder, null); - - verify(mockRawKeyValues, times(1)).groupByKey(); - verify(mockRawKeyValues, never()).groupByKey(any(Partitioner.class)); - } - private GroupByKeyIterator createGbkIterator() throws Coder.NonDeterministicException { return createGbkIterator(