Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -843,8 +843,17 @@ public InputT element(DoFn<InputT, OutputT> doFn) {
}

@Override
public Object sideInput(String tagId) {
throw new UnsupportedOperationException("SideInput parameters are not supported.");
public @Nullable Object sideInput(String tagId) {
// Look up the view registered for this tag id in sideInputMapping. checkStateNotNull guards
// against a tag with no entry, and its message names the offending tag.
PCollectionView<?> view =
checkStateNotNull(sideInputMapping.get(tagId), "Side input tag %s not found", tagId);
// Delegate to the PCollectionView overload, which resolves the value for the current window.
return sideInput(view);
}

/**
 * Returns the value of {@code view} for the side-input window that the view's window mapping
 * function derives from the current {@code window()}.
 *
 * <p>The lookup itself is delegated to the enclosing {@code SimpleDoFnRunner}'s {@code sideInput}.
 *
 * @param view the side input view to read; checked with {@code checkNotNull} before use
 * @return the result of {@code SimpleDoFnRunner.this.sideInput(...)} for the mapped window
 */
@Override
public <T> T sideInput(PCollectionView<T> view) {
checkNotNull(view, "View passed to sideInput cannot be null");
// Map the current main-input window to the corresponding side-input window before reading.
return SimpleDoFnRunner.this.sideInput(
view, view.getWindowMappingFn().getSideInputWindow(window()));
}

@Override
Expand Down Expand Up @@ -1147,8 +1156,17 @@ public InputT element(DoFn<InputT, OutputT> doFn) {
}

@Override
public Object sideInput(String tagId) {
throw new UnsupportedOperationException("SideInput parameters are not supported.");
public @Nullable Object sideInput(String tagId) {
// Look up the view registered for this tag id in sideInputMapping. checkStateNotNull guards
// against a tag with no entry, and its message names the offending tag.
PCollectionView<?> view =
checkStateNotNull(sideInputMapping.get(tagId), "Side input tag %s not found", tagId);
// Delegate to the PCollectionView overload, which resolves the value for the current window.
return sideInput(view);
}

/**
 * Returns the value of {@code view} for the side-input window that the view's window mapping
 * function derives from the current {@code window()}.
 *
 * <p>The lookup itself is delegated to the enclosing {@code SimpleDoFnRunner}'s {@code sideInput}.
 *
 * @param view the side input view to read; checked with {@code checkNotNull} before use
 * @return the result of {@code SimpleDoFnRunner.this.sideInput(...)} for the mapped window
 */
@Override
public <T> T sideInput(PCollectionView<T> view) {
checkNotNull(view, "View passed to sideInput cannot be null");
// Map the current main-input window to the corresponding side-input window before reading.
return SimpleDoFnRunner.this.sideInput(
view, view.getWindowMappingFn().getSideInputWindow(window()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.StreamingOptions;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.util.WindowedValueMultiReceiver;
Expand Down Expand Up @@ -68,17 +67,6 @@ public DoFnRunner<InputT, OutputT> createRunner(
windowingStrategy,
doFnSchemaInformation,
sideInputMapping);
boolean hasStreamingSideInput =
options.as(StreamingOptions.class).isStreaming() && !sideInputReader.isEmpty();
if (hasStreamingSideInput) {
return new StreamingSideInputDoFnRunner<>(
fnRunner,
new StreamingSideInputFetcher<>(
sideInputViews,
inputCoder,
windowingStrategy,
(StreamingModeExecutionContext.StreamingModeStepContext) userStepContext));
}
return fnRunner;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
import org.joda.time.Instant;
Expand All @@ -76,7 +77,7 @@
"rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class SimpleParDoFn<InputT, OutputT> implements ParDoFn {
public class SimpleParDoFn<InputT, OutputT, W extends BoundedWindow> implements ParDoFn {

// TODO: Remove once Distributions has shipped.
@VisibleForTesting
Expand Down Expand Up @@ -112,6 +113,8 @@ public class SimpleParDoFn<InputT, OutputT> implements ParDoFn {
// GroupAlsoByWindowViaWindowSetDoFn
private @Nullable DoFnSignature fnSignature;

private @Nullable StreamingSideInputProcessor<InputT, W> sideInputProcessor;

/** Creates a {@link SimpleParDoFn} using basic information about the step being executed. */
SimpleParDoFn(
PipelineOptions options,
Expand Down Expand Up @@ -317,8 +320,31 @@ public <TagT> void output(TupleTag<TagT> tag, WindowedValue<TagT> output) {
outputManager,
doFnSchemaInformation,
sideInputMapping);
if (hasStreamingSideInput) {
sideInputProcessor =
new StreamingSideInputProcessor<>(
new StreamingSideInputFetcher<InputT, W>(
fnInfo.getSideInputViews(),
fnInfo.getInputCoder(),
(WindowingStrategy<?, W>) fnInfo.getWindowingStrategy(),
(StreamingModeExecutionContext.StreamingModeStepContext) userStepContext));
}

fnRunner.startBundle();
if (sideInputProcessor != null) {
boolean hasState = fnSignature != null && !fnSignature.stateDeclarations().isEmpty();
Iterable<WindowedValue<InputT>> unblockedElements = sideInputProcessor.tryUnblockElements();
for (WindowedValue<InputT> unblockedElement : unblockedElements) {
fnRunner.processElement(unblockedElement);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The outputsPerElementTracker is not notified when processing unblocked elements in startBundle. This can lead to incorrect progress tracking and metrics in Dataflow, as the tracker expects onProcessElement() to be called before processElement(). You should wrap the processElement call with the appropriate tracker notifications.

      for (WindowedValue<InputT> unblockedElement : unblockedElements) {
        outputsPerElementTracker.onProcessElement();
        fnRunner.processElement(unblockedElement);
        outputsPerElementTracker.onProcessElementSuccess();

if (hasState) {
// These elements are now processed. Register cleanup timers for all the unblocked
// windows.
registerStateCleanup(
(WindowingStrategy<?, W>) getDoFnInfo().getWindowingStrategy(),
(Collection<W>) unblockedElement.getWindows());
}
}
}
}

@Override
Expand All @@ -334,14 +360,28 @@ public void processElement(Object untypedElem) throws Exception {

WindowedValue<InputT> elem = (WindowedValue<InputT>) untypedElem;

if (fnSignature != null && fnSignature.stateDeclarations().size() > 0) {
boolean hasState = fnSignature != null && !fnSignature.stateDeclarations().isEmpty();
outputsPerElementTracker.onProcessElement();

Collection<W> windowsProcessed;
if (sideInputProcessor != null) {
windowsProcessed = Lists.newArrayList();
Iterable<WindowedValue<InputT>> elementsToProcess =
sideInputProcessor.handleProcessElement(elem);
for (WindowedValue<InputT> toProcess : elementsToProcess) {
fnRunner.processElement(toProcess);
windowsProcessed.addAll((Collection<W>) toProcess.getWindows());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The windowsProcessed list is allocated and populated even if hasState is false, in which case it is never used. Deferring the allocation and population of this list until it's known that state cleanup is required would be more efficient.

Suggested change
windowsProcessed = Lists.newArrayList();
Iterable<WindowedValue<InputT>> elementsToProcess =
sideInputProcessor.handleProcessElement(elem);
for (WindowedValue<InputT> toProcess : elementsToProcess) {
fnRunner.processElement(toProcess);
windowsProcessed.addAll((Collection<W>) toProcess.getWindows());
windowsProcessed = hasState ? Lists.newArrayList() : ImmutableList.of();
Iterable<WindowedValue<InputT>> elementsToProcess =
sideInputProcessor.handleProcessElement(elem);
for (WindowedValue<InputT> toProcess : elementsToProcess) {
fnRunner.processElement(toProcess);
if (hasState) {
windowsProcessed.addAll((Collection<W>) toProcess.getWindows());
}

// If the element was blocked, don't register a cleanup timer. The timer will be registered
// when the window is unblocked ensuring that it is not processed until the element is.
}
} else {
fnRunner.processElement(elem);
windowsProcessed = (Collection<W>) elem.getWindows();
}
if (hasState) {
registerStateCleanup(
(WindowingStrategy<?, BoundedWindow>) getDoFnInfo().getWindowingStrategy(),
(Collection<BoundedWindow>) elem.getWindows());
(WindowingStrategy<?, W>) getDoFnInfo().getWindowingStrategy(), windowsProcessed);
}

outputsPerElementTracker.onProcessElement();
fnRunner.processElement(elem);
outputsPerElementTracker.onProcessElementSuccess();
}

Expand All @@ -367,6 +407,9 @@ private void processUserTimer(TimerData timer) throws Exception {
if (fnSignature.timerDeclarations().containsKey(timer.getTimerId())
|| fnSignature.timerFamilyDeclarations().containsKey(timer.getTimerFamilyId())) {
BoundedWindow window = ((WindowNamespace) timer.getNamespace()).getWindow();
if (sideInputProcessor != null) {
sideInputProcessor.handleProcessTimer(timer);
}
fnRunner.onTimer(
timer.getTimerId(),
timer.getTimerFamilyId(),
Expand All @@ -380,7 +423,6 @@ private void processUserTimer(TimerData timer) throws Exception {
}

private void processSystemTimer(TimerData timer) throws Exception {

// Timer owned by this class, for cleaning up state in expired windows
if (timer.getTimerId().equals(CLEANUP_TIMER_ID)) {
checkState(
Expand All @@ -396,6 +438,13 @@ private void processSystemTimer(TimerData timer) throws Exception {
WindowNamespace.class.getSimpleName(),
timer);

if (sideInputProcessor != null) {
// We must call this to ensure the side input is cached for onWindowExpiration. Since we
// don't set cleanup timers until we actually call processElement, the window must be
// unblocked here.
sideInputProcessor.handleProcessTimer(timer);
}

BoundedWindow window = ((WindowNamespace) timer.getNamespace()).getWindow();
Instant targetTime = earliestAllowableCleanupTime(window, fnInfo.getWindowingStrategy());

Expand Down Expand Up @@ -436,10 +485,14 @@ private void processSystemTimer(TimerData timer) throws Exception {
public void finishBundle() throws Exception {
// No active runner means there is nothing to finish or release.
if (fnRunner == null) {
return;
}

// Finish the user's bundle first, then let the side-input processor (if any) run its own
// finish-bundle handling.
fnRunner.finishBundle();
if (sideInputProcessor != null) {
sideInputProcessor.handleFinishBundle();
}
doFnInstanceManager.complete(fnInfo);

// Clear the runner and its companions so the guard above skips any repeated call.
sideInputProcessor = null;
fnSignature = null;
fnInfo = null;
fnRunner = null;
}

Expand Down Expand Up @@ -490,7 +543,7 @@ private void processTimers(
}
}

private <W extends BoundedWindow> void registerStateCleanup(
private void registerStateCleanup(
WindowingStrategy<?, W> windowingStrategy, Collection<W> windowsToCleanup) {
Coder<W> windowCoder = windowingStrategy.getWindowFn().windowCoder();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker;

import java.util.Set;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
Expand All @@ -37,43 +35,30 @@
public class StreamingSideInputDoFnRunner<InputT, OutputT, W extends BoundedWindow>
implements DoFnRunner<InputT, OutputT> {
private final DoFnRunner<InputT, OutputT> simpleDoFnRunner;
private final StreamingSideInputFetcher<InputT, W> sideInputFetcher;
private final StreamingSideInputProcessor<InputT, W> sideInputProcessor;

public StreamingSideInputDoFnRunner(
DoFnRunner<InputT, OutputT> simpleDoFnRunner,
StreamingSideInputFetcher<InputT, W> sideInputFetcher) {
this.simpleDoFnRunner = simpleDoFnRunner;
this.sideInputFetcher = sideInputFetcher;
this.sideInputProcessor = new StreamingSideInputProcessor<>(sideInputFetcher);
}

@Override
public void startBundle() {
simpleDoFnRunner.startBundle();
sideInputFetcher.prefetchBlockedMap();

// Find the set of ready windows.
Set<W> readyWindows = sideInputFetcher.getReadyWindows();

Iterable<BagState<WindowedValue<InputT>>> elementsBags =
sideInputFetcher.prefetchElements(readyWindows);

// Run the DoFn code now that all side inputs are ready.
for (BagState<WindowedValue<InputT>> elementsBag : elementsBags) {
Iterable<WindowedValue<InputT>> elements = elementsBag.read();
for (WindowedValue<InputT> elem : elements) {
simpleDoFnRunner.processElement(elem);
}
elementsBag.clear();
Iterable<WindowedValue<InputT>> unblocked = sideInputProcessor.tryUnblockElements();
for (WindowedValue<InputT> elem : unblocked) {
simpleDoFnRunner.processElement(elem);
}
sideInputFetcher.releaseBlockedWindows(readyWindows);
}

@Override
public void processElement(WindowedValue<InputT> compressedElem) {
for (WindowedValue<InputT> elem : compressedElem.explodeWindows()) {
if (!sideInputFetcher.storeIfBlocked(elem)) {
simpleDoFnRunner.processElement(elem);
}
Iterable<WindowedValue<InputT>> unblocked =
sideInputProcessor.handleProcessElement(compressedElem);
for (WindowedValue<InputT> elem : unblocked) {
simpleDoFnRunner.processElement(elem);
}
}

Expand All @@ -94,7 +79,7 @@ public <KeyT> void onTimer(
@Override
public void finishBundle() {
simpleDoFnRunner.finishBundle();
sideInputFetcher.persist();
sideInputProcessor.handleFinishBundle();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -83,7 +84,6 @@ public StreamingSideInputFetcher(
this.stepContext = stepContext;

this.mainWindowCoder = windowingStrategy.getWindowFn().windowCoder();

this.sideInputViews = new HashMap<>();
for (PCollectionView<?> view : views) {
sideInputViews.put(view.getTagInternal().getId(), view);
Expand Down Expand Up @@ -188,11 +188,7 @@ public Iterable<BagState<TimerData>> prefetchTimers(Iterable<W> readyWindows) {
return timers;
}

/** Compute the set of side inputs that are not yet ready for the given main input window. */
public boolean storeIfBlocked(WindowedValue<InputT> elem) {
@SuppressWarnings("unchecked")
W window = (W) Iterables.getOnlyElement(elem.getWindows());

private Set<Windmill.GlobalDataRequest> checkIfBlocked(W window) {
Set<Windmill.GlobalDataRequest> blocked = blockedMap().get(window);
if (blocked == null) {
for (PCollectionView<?> view : sideInputViews.values()) {
Expand All @@ -205,7 +201,16 @@ public boolean storeIfBlocked(WindowedValue<InputT> elem) {
}
}
}
if (blocked != null) {
return blocked == null ? Collections.emptySet() : blocked;
}

/** Compute the set of side inputs that are not yet ready for the given main input window. */
public boolean storeIfBlocked(WindowedValue<InputT> elem) {
@SuppressWarnings("unchecked")
W window = (W) Iterables.getOnlyElement(elem.getWindows());

Set<Windmill.GlobalDataRequest> blocked = checkIfBlocked(window);
if (!blocked.isEmpty()) {
elementBag(window).add(elem);
watermarkHold(window).add(elem.getTimestamp());
stepContext.addBlockingSideInputs(blocked);
Expand All @@ -223,17 +228,12 @@ public boolean storeIfBlocked(TimerData timer) {
@SuppressWarnings("unchecked")
WindowNamespace<W> windowNamespace = (WindowNamespace<W>) timer.getNamespace();
W window = windowNamespace.getWindow();

boolean blocked = false;
for (PCollectionView<?> view : sideInputViews.values()) {
if (!stepContext.issueSideInputFetch(view, window, SideInputState.UNKNOWN)) {
blocked = true;
}
}
if (blocked) {
Set<Windmill.GlobalDataRequest> blocked = checkIfBlocked(window);
if (!blocked.isEmpty()) {
timerBag(window).add(timer);
return true;
}
return blocked;
return false;
}

public void persist() {
Expand Down
Loading
Loading