diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 1e60348eee02..2b449e0200b1 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -119,10 +119,8 @@ import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.util.Durations; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableListMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Table; @@ -200,8 +198,7 @@ static class Factory> coder = entry.getValue().getValue(); if (!localName.equals("") && localName.equals(runner.parDoPayload.getOnWindowExpirationTimerFamilySpec())) { - context.addIncomingTimerEndpoint( - localName, coder, timer -> runner.processOnWindowExpiration(timer)); + context.addIncomingTimerEndpoint(localName, coder, runner::processOnWindowExpiration); } else { context.addIncomingTimerEndpoint( localName, coder, timer -> runner.processTimer(localName, timeDomain, timer)); @@ -230,10 +227,10 @@ static class Factory, Coder> outputCoders; private final Map>>> timerFamilyInfos; private final ParDoPayload parDoPayload; - private final ListMultimap>> localNameToConsumer; + private final Map>> localNameToConsumer; private final BundleSplitListener splitListener; private final BundleFinalizer bundleFinalizer; - private final Collection>> mainOutputConsumers; + private final FnDataReceiver> mainOutputConsumer; private final String mainInputId; private final FnApiStateAccessor stateAccessor; @@ -478,10 +475,10 @@ static class Factory>> - localNameToConsumerBuilder = ImmutableListMultimap.builder(); + ImmutableMap.Builder>> localNameToConsumerBuilder = + ImmutableMap.builder(); for (Map.Entry entry : pTransform.getOutputsMap().entrySet()) { - localNameToConsumerBuilder.putAll( + localNameToConsumerBuilder.put( entry.getKey(), getPCollectionConsumer.apply(entry.getValue())); } localNameToConsumer = localNameToConsumerBuilder.build(); @@ -491,9 +488,9 @@ static class Factory(); - this.mainOutputConsumers = - (Collection>>) - (Collection) localNameToConsumer.get(mainOutputTag.getId()); + this.mainOutputConsumer = + (FnDataReceiver>) + (FnDataReceiver) localNameToConsumer.get(mainOutputTag.getId()); this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(parDoPayload); this.sideInputMapping = ParDoTranslation.getSideInputMapping(parDoPayload); this.doFnInvoker = DoFnInvokers.tryInvokeSetupFor(doFn, pipelineOptions); @@ -569,12 +566,10 @@ static class Factory elem) { try { currentRestriction = doFnInvoker.invokeGetInitialRestriction(processContext); outputTo( - mainOutputConsumers, + mainOutputConsumer, (WindowedValue) elem.withValue( KV.of( @@ -855,7 +848,7 @@ private void processElementForWindowObservingPairWithRestriction(WindowedValue void outputTo( - Collection>> consumers, WindowedValue output) { + private void outputTo(FnDataReceiver> consumer, WindowedValue output) { if (currentWatermarkEstimator instanceof TimestampObservingWatermarkEstimator) { ((TimestampObservingWatermarkEstimator) currentWatermarkEstimator) .observeTimestamp(output.getTimestamp()); } try { - for (FnDataReceiver> consumer : consumers) { - consumer.accept(output); - } + consumer.accept(output); } catch (Throwable t) { throw UserCodeException.wrap(t); } @@ -2136,17 +2126,17 @@ public PipelineOptions getPipelineOptions() { @Override public void output(OutputT output, Instant timestamp, BoundedWindow window) { outputTo( - mainOutputConsumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + mainOutputConsumer, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } @Override public void output(TupleTag tag, T output, Instant timestamp, BoundedWindow window) { - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } - outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); + outputTo(consumer, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING)); } } @@ -2248,20 +2238,20 @@ public TimerMap timerFamily(String timerFamilyId) { public void outputWithTimestamp(OutputT output, Instant timestamp) { // TODO: Check that timestamp is valid once all runners can provide proper timestamps. outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane())); } @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { // TODO: Check that timestamp is valid once all runners can provide proper timestamps. - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane())); + consumer, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane())); } } @@ -2299,7 +2289,7 @@ public Instant timestamp(DoFn doFn) { }); outputTo( - mainOutputConsumers, + mainOutputConsumer, (WindowedValue) WindowedValue.of( KV.of( @@ -2346,7 +2336,7 @@ public Instant timestamp(DoFn doFn) { }); outputTo( - mainOutputConsumers, + mainOutputConsumer, (WindowedValue) WindowedValue.of( KV.of( @@ -2365,7 +2355,7 @@ private class NonWindowObservingProcessBundleContext extends ProcessBundleContex public void outputWithTimestamp(OutputT output, Instant timestamp) { checkTimestamp(timestamp); outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of( output, timestamp, currentElement.getWindows(), currentElement.getPane())); } @@ -2373,13 +2363,13 @@ public void outputWithTimestamp(OutputT output, Instant timestamp) { @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { checkTimestamp(timestamp); - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, + consumer, WindowedValue.of( output, timestamp, currentElement.getWindows(), currentElement.getPane())); } @@ -2591,7 +2581,7 @@ public BoundedWindow window() { @Override public void output(OutputT output) { outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of( output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); } @@ -2600,19 +2590,19 @@ public void output(OutputT output) { public void outputWithTimestamp(OutputT output, Instant timestamp) { checkOnWindowExpirationTimestamp(timestamp); outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); } @Override public void output(TupleTag tag, T output) { - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, + consumer, WindowedValue.of( output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); } @@ -2620,13 +2610,13 @@ public void output(TupleTag tag, T output) { @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { checkOnWindowExpirationTimestamp(timestamp); - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); + consumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); } @SuppressWarnings( @@ -2745,7 +2735,7 @@ public BoundedWindow window() { public void output(OutputT output) { checkTimerTimestamp(currentTimer.getHoldTimestamp()); outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of( output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); } @@ -2754,20 +2744,20 @@ public void output(OutputT output) { public void outputWithTimestamp(OutputT output, Instant timestamp) { checkTimerTimestamp(timestamp); outputTo( - mainOutputConsumers, + mainOutputConsumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); } @Override public void output(TupleTag tag, T output) { checkTimerTimestamp(currentTimer.getHoldTimestamp()); - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, + consumer, WindowedValue.of( output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane())); } @@ -2775,13 +2765,13 @@ public void output(TupleTag tag, T output) { @Override public void outputWithTimestamp(TupleTag tag, T output, Instant timestamp) { checkTimerTimestamp(timestamp); - Collection>> consumers = - (Collection) localNameToConsumer.get(tag.getId()); - if (consumers == null) { + FnDataReceiver> consumer = + (FnDataReceiver) localNameToConsumer.get(tag.getId()); + if (consumer == null) { throw new IllegalArgumentException(String.format("Unknown output tag %s", tag)); } outputTo( - consumers, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); + consumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane())); } @Override