diff --git a/python/ray/data/_internal/execution/operators/union_operator.py b/python/ray/data/_internal/execution/operators/union_operator.py index 8eab50403390..8ecf8e7682e8 100644 --- a/python/ray/data/_internal/execution/operators/union_operator.py +++ b/python/ray/data/_internal/execution/operators/union_operator.py @@ -35,23 +35,22 @@ def __init__( self._preserve_order = False # Intermediary buffers used to store blocks from each input dependency. - # Only used when `self._prserve_order` is True. + # Only used when `self._preserve_order` is True. self._input_buffers: List[BundleQueue] = [ FIFOBundleQueue() for _ in range(len(input_ops)) ] - # The index of the input dependency that is currently the source of - # the output buffer. New inputs from this input dependency will be added - # directly to the output buffer. Only used when `self._preserve_order` is True. - self._input_idx_to_output = 0 + self._input_done_flags: List[bool] = [False] * len(input_ops) self._output_buffer: collections.deque[RefBundle] = collections.deque() self._stats: StatsDict = {"Union": []} + self._current_input_index = 0 super().__init__(data_context, *input_ops) def start(self, options: ExecutionOptions): - # Whether to preserve the order of the input data (both the - # order of the input operators and the order of the blocks within). + # Whether to preserve deterministic ordering of output blocks. + # When True, blocks are emitted in round-robin order across inputs, + # ensuring the same input always produces the same output order. self._preserve_order = options.preserve_order super().start(options) @@ -101,27 +100,27 @@ def clear_internal_output_queue(self) -> None: def _add_input_inner(self, refs: RefBundle, input_index: int) -> None: assert not self.has_completed() assert 0 <= input_index <= len(self._input_dependencies), input_index - - if not self._preserve_order: - self._output_buffer.append(refs) - self._metrics.on_output_queued(refs) - else: + if self._preserve_order: self._input_buffers[input_index].add(refs) self._metrics.on_input_queued(refs) + self._try_round_robin() + else: + self._output_buffer.append(refs) + self._metrics.on_output_queued(refs) + + def input_done(self, input_index: int) -> None: + self._input_done_flags[input_index] = True + if self._preserve_order: + self._try_round_robin() def all_inputs_done(self) -> None: super().all_inputs_done() if not self._preserve_order: return - - assert len(self._output_buffer) == 0, len(self._output_buffer) - for input_buffer in self._input_buffers: - while input_buffer: - refs = input_buffer.get_next() - self._metrics.on_input_dequeued(refs) - self._output_buffer.append(refs) - self._metrics.on_output_queued(refs) + while any(buffer.has_next() for buffer in self._input_buffers): + self._try_round_robin() + assert all(not buffer.has_next() for buffer in self._input_buffers) def has_next(self) -> bool: # Check if the output buffer still contains at least one block. @@ -134,3 +133,25 @@ def _get_next_inner(self) -> RefBundle: def get_stats(self) -> StatsDict: return self._stats + + def _try_round_robin(self) -> None: + """Try to move blocks from input buffers to output in round-robin order. + + Pulls one block from the current input, then advances to the next. + If the current input's buffer is empty but not exhausted, stops and + waits (blocking behavior) to maintain deterministic ordering. + """ + num_inputs = len(self._input_buffers) + + for _ in range(num_inputs): + buffer = self._input_buffers[self._current_input_index] + + if buffer.has_next(): + refs = buffer.get_next() + self._metrics.on_input_dequeued(refs) + self._output_buffer.append(refs) + self._metrics.on_output_queued(refs) + elif not self._input_done_flags[self._current_input_index]: + break + + self._current_input_index = (self._current_input_index + 1) % num_inputs