Skip to content

Commit

Permalink
[#21250] Trivial removal of loop over something that always has one e…
Browse files Browse the repository at this point in the history
…lement (#24014)

Multiplexing was put into the PCollectionConsumerRegistry a long time ago and this seems to have been missed during that migration.
  • Loading branch information
lukecwik authored Nov 8, 2022
1 parent a6a9b23 commit 7dac3f5
Showing 1 changed file with 54 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,8 @@
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.util.Durations;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Table;
Expand Down Expand Up @@ -200,8 +198,7 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
Coder<Timer<Object>> coder = entry.getValue().getValue();
if (!localName.equals("")
&& localName.equals(runner.parDoPayload.getOnWindowExpirationTimerFamilySpec())) {
context.addIncomingTimerEndpoint(
localName, coder, timer -> runner.processOnWindowExpiration(timer));
context.addIncomingTimerEndpoint(localName, coder, runner::processOnWindowExpiration);
} else {
context.addIncomingTimerEndpoint(
localName, coder, timer -> runner.processTimer(localName, timeDomain, timer));
Expand Down Expand Up @@ -230,10 +227,10 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
private final Map<TupleTag<?>, Coder<?>> outputCoders;
private final Map<String, KV<TimeDomain, Coder<Timer<Object>>>> timerFamilyInfos;
private final ParDoPayload parDoPayload;
private final ListMultimap<String, FnDataReceiver<WindowedValue<?>>> localNameToConsumer;
private final Map<String, FnDataReceiver<WindowedValue<?>>> localNameToConsumer;
private final BundleSplitListener splitListener;
private final BundleFinalizer bundleFinalizer;
private final Collection<FnDataReceiver<WindowedValue<OutputT>>> mainOutputConsumers;
private final FnDataReceiver<WindowedValue<OutputT>> mainOutputConsumer;

private final String mainInputId;
private final FnApiStateAccessor<?> stateAccessor;
Expand Down Expand Up @@ -478,10 +475,10 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
throw new IllegalArgumentException("Malformed ParDoPayload", exn);
}

ImmutableListMultimap.Builder<String, FnDataReceiver<WindowedValue<?>>>
localNameToConsumerBuilder = ImmutableListMultimap.builder();
ImmutableMap.Builder<String, FnDataReceiver<WindowedValue<?>>> localNameToConsumerBuilder =
ImmutableMap.builder();
for (Map.Entry<String, String> entry : pTransform.getOutputsMap().entrySet()) {
localNameToConsumerBuilder.putAll(
localNameToConsumerBuilder.put(
entry.getKey(), getPCollectionConsumer.apply(entry.getValue()));
}
localNameToConsumer = localNameToConsumerBuilder.build();
Expand All @@ -491,9 +488,9 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
this.onTimerContext = new OnTimerContext();
this.onWindowExpirationContext = new OnWindowExpirationContext<>();

this.mainOutputConsumers =
(Collection<FnDataReceiver<WindowedValue<OutputT>>>)
(Collection) localNameToConsumer.get(mainOutputTag.getId());
this.mainOutputConsumer =
(FnDataReceiver<WindowedValue<OutputT>>)
(FnDataReceiver) localNameToConsumer.get(mainOutputTag.getId());
this.doFnSchemaInformation = ParDoTranslation.getSchemaInformation(parDoPayload);
this.sideInputMapping = ParDoTranslation.getSideInputMapping(parDoPayload);
this.doFnInvoker = DoFnInvokers.tryInvokeSetupFor(doFn, pipelineOptions);
Expand Down Expand Up @@ -569,12 +566,10 @@ static class Factory<InputT, RestrictionT, PositionT, WatermarkEstimatorStateT,
|| (doFnSignature.getSize() != null && doFnSignature.getSize().observesWindow())
|| !sideInputMapping.isEmpty()) {
// Only forward split/progress when the only consumer is splittable.
if (mainOutputConsumers.size() == 1
&& Iterables.getOnlyElement(mainOutputConsumers) instanceof HandlesSplits) {
if (mainOutputConsumer instanceof HandlesSplits) {
mainInputConsumer =
new SplittableFnDataReceiver() {
private final HandlesSplits splitDelegate =
(HandlesSplits) Iterables.getOnlyElement(mainOutputConsumers);
private final HandlesSplits splitDelegate = (HandlesSplits) mainOutputConsumer;

@Override
public void accept(WindowedValue input) throws Exception {
Expand Down Expand Up @@ -609,12 +604,10 @@ public double getProgress() {
PTransformTranslation.SPLITTABLE_TRUNCATE_SIZED_RESTRICTION_URN);
} else {
// Only forward split/progress when the only consumer is splittable.
if (mainOutputConsumers.size() == 1
&& Iterables.getOnlyElement(mainOutputConsumers) instanceof HandlesSplits) {
if (mainOutputConsumer instanceof HandlesSplits) {
mainInputConsumer =
new SplittableFnDataReceiver() {
private final HandlesSplits splitDelegate =
(HandlesSplits) Iterables.getOnlyElement(mainOutputConsumers);
private final HandlesSplits splitDelegate = (HandlesSplits) mainOutputConsumer;

@Override
public void accept(WindowedValue input) throws Exception {
Expand Down Expand Up @@ -830,7 +823,7 @@ private void processElementForPairWithRestriction(WindowedValue<InputT> elem) {
try {
currentRestriction = doFnInvoker.invokeGetInitialRestriction(processContext);
outputTo(
mainOutputConsumers,
mainOutputConsumer,
(WindowedValue)
elem.withValue(
KV.of(
Expand All @@ -855,7 +848,7 @@ private void processElementForWindowObservingPairWithRestriction(WindowedValue<I
currentWindow = windowIterator.next();
currentRestriction = doFnInvoker.invokeGetInitialRestriction(processContext);
outputTo(
mainOutputConsumers,
mainOutputConsumer,
(WindowedValue)
WindowedValue.of(
KV.of(
Expand Down Expand Up @@ -1787,16 +1780,13 @@ private void tearDown() {
}

/** Outputs the given element to the specified set of consumers wrapping any exceptions. */
private <T> void outputTo(
Collection<FnDataReceiver<WindowedValue<T>>> consumers, WindowedValue<T> output) {
private <T> void outputTo(FnDataReceiver<WindowedValue<T>> consumer, WindowedValue<T> output) {
if (currentWatermarkEstimator instanceof TimestampObservingWatermarkEstimator) {
((TimestampObservingWatermarkEstimator) currentWatermarkEstimator)
.observeTimestamp(output.getTimestamp());
}
try {
for (FnDataReceiver<WindowedValue<T>> consumer : consumers) {
consumer.accept(output);
}
consumer.accept(output);
} catch (Throwable t) {
throw UserCodeException.wrap(t);
}
Expand Down Expand Up @@ -2136,17 +2126,17 @@ public PipelineOptions getPipelineOptions() {
@Override
public void output(OutputT output, Instant timestamp, BoundedWindow window) {
outputTo(
mainOutputConsumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
mainOutputConsumer, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
}

@Override
public <T> void output(TupleTag<T> tag, T output, Instant timestamp, BoundedWindow window) {
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(consumers, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
outputTo(consumer, WindowedValue.of(output, timestamp, window, PaneInfo.NO_FIRING));
}
}

Expand Down Expand Up @@ -2248,20 +2238,20 @@ public TimerMap timerFamily(String timerFamilyId) {
public void outputWithTimestamp(OutputT output, Instant timestamp) {
// TODO: Check that timestamp is valid once all runners can provide proper timestamps.
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}

@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
// TODO: Check that timestamp is valid once all runners can provide proper timestamps.
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
consumer, WindowedValue.of(output, timestamp, currentWindow, currentElement.getPane()));
}
}

Expand Down Expand Up @@ -2299,7 +2289,7 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
});

outputTo(
mainOutputConsumers,
mainOutputConsumer,
(WindowedValue<OutputT>)
WindowedValue.of(
KV.of(
Expand Down Expand Up @@ -2346,7 +2336,7 @@ public Instant timestamp(DoFn<InputT, OutputT> doFn) {
});

outputTo(
mainOutputConsumers,
mainOutputConsumer,
(WindowedValue<OutputT>)
WindowedValue.of(
KV.of(
Expand All @@ -2365,21 +2355,21 @@ private class NonWindowObservingProcessBundleContext extends ProcessBundleContex
public void outputWithTimestamp(OutputT output, Instant timestamp) {
checkTimestamp(timestamp);
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(
output, timestamp, currentElement.getWindows(), currentElement.getPane()));
}

@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
checkTimestamp(timestamp);
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers,
consumer,
WindowedValue.of(
output, timestamp, currentElement.getWindows(), currentElement.getPane()));
}
Expand Down Expand Up @@ -2591,7 +2581,7 @@ public BoundedWindow window() {
@Override
public void output(OutputT output) {
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(
output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane()));
}
Expand All @@ -2600,33 +2590,33 @@ public void output(OutputT output) {
public void outputWithTimestamp(OutputT output, Instant timestamp) {
checkOnWindowExpirationTimestamp(timestamp);
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
}

@Override
public <T> void output(TupleTag<T> tag, T output) {
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers,
consumer,
WindowedValue.of(
output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane()));
}

@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
checkOnWindowExpirationTimestamp(timestamp);
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
consumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
}

@SuppressWarnings(
Expand Down Expand Up @@ -2745,7 +2735,7 @@ public BoundedWindow window() {
public void output(OutputT output) {
checkTimerTimestamp(currentTimer.getHoldTimestamp());
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(
output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane()));
}
Expand All @@ -2754,34 +2744,34 @@ public void output(OutputT output) {
public void outputWithTimestamp(OutputT output, Instant timestamp) {
checkTimerTimestamp(timestamp);
outputTo(
mainOutputConsumers,
mainOutputConsumer,
WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
}

@Override
public <T> void output(TupleTag<T> tag, T output) {
checkTimerTimestamp(currentTimer.getHoldTimestamp());
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers,
consumer,
WindowedValue.of(
output, currentTimer.getHoldTimestamp(), currentWindow, currentTimer.getPane()));
}

@Override
public <T> void outputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) {
checkTimerTimestamp(timestamp);
Collection<FnDataReceiver<WindowedValue<T>>> consumers =
(Collection) localNameToConsumer.get(tag.getId());
if (consumers == null) {
FnDataReceiver<WindowedValue<T>> consumer =
(FnDataReceiver) localNameToConsumer.get(tag.getId());
if (consumer == null) {
throw new IllegalArgumentException(String.format("Unknown output tag %s", tag));
}
outputTo(
consumers, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
consumer, WindowedValue.of(output, timestamp, currentWindow, currentTimer.getPane()));
}

@Override
Expand Down

0 comments on commit 7dac3f5

Please sign in to comment.