diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index 4248a88c7b102..666156af82d10 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -73,6 +73,10 @@ class StatefulStageUDF: column for the next stage. """ + # The internal column name for the index of the row in the batch. + # This is used to align the output of the UDF with the input batch. + idx_in_batch_column: str = "__idx_in_batch" + def __init__(self, data_column: str): self.data_column = data_column @@ -143,6 +147,10 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] inputs = inputs.tolist() self.validate_inputs(inputs) + # Assign the index of the row in the batch to the idx_in_batch_column. + for idx, row in enumerate(inputs): + row[self.idx_in_batch_column] = idx + # Always stream the outputs one by one to better overlapping # batches. For example, when the output batch size is 64, Ray Data # will collect 64 outputs, and 1) send the batch of 64 to the next stage, @@ -151,12 +159,29 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] # for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand, # if we stream outputs one-by-one, Ray Data can form a batch of 64 before # the second batch is done. - idx = 0 + is_outputed = [False] * len(inputs) async for output in self.udf(inputs): + if self.idx_in_batch_column not in output: + raise ValueError( + "The output of the UDF must contain the column " + f"{self.idx_in_batch_column}." + ) + idx_in_batch = output.pop(self.idx_in_batch_column) + if is_outputed[idx_in_batch]: + raise ValueError( + f"The row {idx_in_batch} is outputed twice. " + "This is likely due to the UDF is not one-to-one." + ) + is_outputed[idx_in_batch] = True + # Add stage outputs to the data column of the row. - inputs[idx].update(output) - yield {self.data_column: [inputs[idx]]} - idx += 1 + inputs[idx_in_batch].pop(self.idx_in_batch_column) + inputs[idx_in_batch].update(output) + yield {self.data_column: [inputs[idx_in_batch]]} + + if not all(is_outputed): + missing_rows = [i for i, o in enumerate(is_outputed) if not o] + raise ValueError(f"The rows {missing_rows} are not outputed.") def validate_inputs(self, inputs: List[Dict[str, Any]]): """Validate the inputs to make sure the required keys are present. @@ -167,11 +192,18 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]): Raises: ValueError: If the required keys are not found. """ + input_keys = set(inputs[0].keys()) + + if self.idx_in_batch_column in input_keys: + raise ValueError( + f"The input column {self.idx_in_batch_column} is reserved " + "for internal use." + ) + expected_input_keys = self.expected_input_keys if not expected_input_keys: return - input_keys = set(inputs[0].keys()) missing_required = set(expected_input_keys) - input_keys if missing_required: raise ValueError( diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py index f41ee1aaec159..1995d45b42c23 100644 --- a/python/ray/llm/_internal/batch/stages/chat_template_stage.py +++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py @@ -37,12 +37,18 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] Yields: A generator of rows with the chat template applied. """ - for prompt in self.tokenizer.apply_chat_template( - [row["messages"].tolist() for row in batch], - tokenize=False, - add_generation_prompt=True, + for row, prompt in zip( + batch, + self.tokenizer.apply_chat_template( + [row["messages"].tolist() for row in batch], + tokenize=False, + add_generation_prompt=True, + ), ): - yield {"prompt": prompt} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "prompt": prompt, + } @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/_internal/batch/stages/http_request_stage.py b/python/ray/llm/_internal/batch/stages/http_request_stage.py index 9b765ec2b8d02..20b6083efda36 100644 --- a/python/ray/llm/_internal/batch/stages/http_request_stage.py +++ b/python/ray/llm/_internal/batch/stages/http_request_stage.py @@ -74,12 +74,21 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] headers=headers, json=json_body, ) - pending_requests.append(request) + pending_requests.append((row[self.idx_in_batch_column], request)) # Now receive all responses - for request in pending_requests: + for idx_in_batch, request in pending_requests: async with await request as response: - yield await response.json() + resp_json = await response.json() + if self.idx_in_batch_column in resp_json: + raise ValueError( + "The response of the HTTP request must not contain " + f"the column {self.idx_in_batch_column}." + ) + yield { + self.idx_in_batch_column: idx_in_batch, + **resp_json, + } class HttpRequestStage(StatefulStage): diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py index 88641a8343ab7..4c933cd96b838 100644 --- a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py +++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py @@ -342,18 +342,30 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] flat_all_image_info = [img for imgs in all_image_info for img in imgs] flat_all_images = await self.image_processor.process(flat_all_image_info) - idx = 0 + # TODO: We now use asyncio.gather to process all images in this batch, + # so the outputs here must be in order. However, it is more efficient + # to support out-of-order outputs so that we won't be blocked by slow + # downloaded images. + img_start_idx = 0 + idx_in_batch = 0 for image_info_per_req in all_image_info: num_images_in_req = len(image_info_per_req) + ret = {self.idx_in_batch_column: idx_in_batch} if num_images_in_req == 0: - yield {} + yield ret else: - images = flat_all_images[idx : idx + num_images_in_req] - yield { - "image": images, - "image_sizes": [(img.width, img.height) for img in images], - } - idx += num_images_in_req + images = flat_all_images[ + img_start_idx : img_start_idx + num_images_in_req + ] + ret.update( + { + "image": images, + "image_sizes": [(img.width, img.height) for img in images], + } + ) + yield ret + img_start_idx += num_images_in_req + idx_in_batch += 1 @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/_internal/batch/stages/tokenize_stage.py b/python/ray/llm/_internal/batch/stages/tokenize_stage.py index 68e7bc53e3c0c..b686e81ac75cc 100644 --- a/python/ray/llm/_internal/batch/stages/tokenize_stage.py +++ b/python/ray/llm/_internal/batch/stages/tokenize_stage.py @@ -41,7 +41,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] batch, self.tokenizer([row["prompt"] for row in batch])["input_ids"], ): - yield {"tokenized_prompt": prompt_token_ids} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "tokenized_prompt": prompt_token_ids, + } @property def expected_input_keys(self) -> List[str]: @@ -96,7 +99,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] skip_special_tokens=True, ), ): - yield {"generated_text": generated_text} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "generated_text": generated_text, + } @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/tests/batch/stages/test_base.py b/python/ray/llm/tests/batch/stages/test_base.py index c99754a23b598..0e98a33b64254 100644 --- a/python/ray/llm/tests/batch/stages/test_base.py +++ b/python/ray/llm/tests/batch/stages/test_base.py @@ -40,11 +40,26 @@ def to_string(x: dict) -> dict: class TestStatefulStageUDF: class SimpleUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + udf_output_missing_idx_in_batch_column: bool = False, + ): + super().__init__(data_column) + self.udf_output_missing_idx_in_batch_column = ( + udf_output_missing_idx_in_batch_column + ) + async def udf( self, rows: list[Dict[str, Any]] ) -> AsyncIterator[Dict[str, Any]]: - for row in rows: - yield {"processed": row["value"] * 2} + + # Intentionally output in a reversed order to test OOO. + for row in rows[::-1]: + ret = {"processed": row["value"] * 2} + if not self.udf_output_missing_idx_in_batch_column: + ret[self.idx_in_batch_column] = row[self.idx_in_batch_column] + yield ret @property def expected_input_keys(self) -> List[str]: @@ -55,7 +70,7 @@ async def test_basic_processing(self): udf = self.SimpleUDF(data_column="__data") batch = { - "__data": [{"value": 1, "extra": "a"}, {"value": 2, "extra": "b"}], + "__data": [{"value": 1, "extra": 10}, {"value": 2, "extra": 20}], } results = [] @@ -63,12 +78,12 @@ async def test_basic_processing(self): results.append(result) assert len(results) == 2 - assert results[0] == { - "__data": [{"processed": 2, "value": 1, "extra": "a"}], - } - assert results[1] == { - "__data": [{"processed": 4, "value": 2, "extra": "b"}], - } + for result in results: + data = result["__data"][0] + val = data["value"] + assert data["processed"] == val * 2 + assert data["extra"] == 10 * val + assert data["value"] == val @pytest.mark.asyncio async def test_missing_data_column(self): @@ -90,6 +105,19 @@ async def test_missing_required_key(self): async for _ in udf(batch): pass + @pytest.mark.asyncio + async def test_missing_idx_in_batch_column(self): + udf = self.SimpleUDF( + data_column="__data", + udf_output_missing_idx_in_batch_column=True, + ) + + batch = {"__data": [{"value": 1}]} + + with pytest.raises(ValueError): + async for _ in udf(batch): + pass + def test_stateful_stage(): udf = TestStatefulStageUDF.SimpleUDF(data_column="__data") diff --git a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py index 4d822d6fcd439..f5a05a0fd93bd 100644 --- a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py +++ b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py @@ -24,25 +24,23 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup): model="test-model", ) - batch = [ - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "Hello AI"}] - ) - } - ] + batch = { + "__data": [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + } + ] + } results = [] - async for result in udf.udf(batch): + async for result in udf(batch): results.append(result) assert len(results) == 1 - assert results[0] == {"prompt": "Hello AI"} - mock_tokenizer.apply_chat_template.assert_called_once_with( - [batch[0]["messages"].tolist()], - tokenize=False, - add_generation_prompt=True, - ) + assert results[0]["__data"][0]["prompt"] == "Hello AI" + mock_tokenizer.apply_chat_template.assert_called_once() @pytest.mark.asyncio @@ -58,31 +56,29 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup): model="test-model", ) - batch = [ - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "Hello AI"}] - ) - }, - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "How are you?"}] - ) - }, - ] + batch = { + "__data": [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + }, + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "How are you?"}], + ) + }, + ] + } results = [] - async for result in udf.udf(batch): + async for result in udf(batch): results.append(result) assert len(results) == 2 - assert results[0] == {"prompt": "Hello AI"} - assert results[1] == {"prompt": "How are you?"} - mock_tokenizer.apply_chat_template.assert_called_once_with( - [msg["messages"].tolist() for msg in batch], - tokenize=False, - add_generation_prompt=True, - ) + assert results[0]["__data"][0]["prompt"] == "Hello AI" + assert results[1]["__data"][0]["prompt"] == "How are you?" + mock_tokenizer.apply_chat_template.assert_called_once() def test_chat_template_udf_expected_input_keys(mock_tokenizer_setup): diff --git a/python/ray/llm/tests/batch/stages/test_http_request_stage.py b/python/ray/llm/tests/batch/stages/test_http_request_stage.py index 636979f893ac6..2477399bb0167 100644 --- a/python/ray/llm/tests/batch/stages/test_http_request_stage.py +++ b/python/ray/llm/tests/batch/stages/test_http_request_stage.py @@ -32,7 +32,7 @@ async def test_http_request_udf_basic(): qps=None, ) - batch = [{"text": "hello", "metadata": "test"}] + batch = {"__data": [{"text": "hello", "metadata": "test"}]} with patch("aiohttp.ClientSession") as mock_session_cls: session = AsyncMock() @@ -41,8 +41,8 @@ async def test_http_request_udf_basic(): ) mock_session_cls.return_value.__aenter__.return_value = session - async for result in udf.udf(batch): - assert result == {"response": "test"} + async for result in udf(batch): + assert result["__data"][0]["response"] == "test" session.post.assert_called_once_with( "http://test.com/api", @@ -50,7 +50,7 @@ async def test_http_request_udf_basic(): "Content-Type": "application/json", "Authorization": "Bearer 1234567890", }, - json={"text": "hello", "metadata": "test"}, + json={"text": "hello", "metadata": "test", "__idx_in_batch": 0}, ) @@ -62,7 +62,7 @@ async def test_http_request_udf_with_qps(): qps=2, ) - batch = [{"text": "hello1"}, {"text": "hello2"}] + batch = {"__data": [{"text": "hello1"}, {"text": "hello2"}]} with patch("aiohttp.ClientSession") as mock_session_cls, patch( "time.time" @@ -80,7 +80,7 @@ async def test_http_request_udf_with_qps(): mock_time.side_effect = [0, 0.1, 0.2] results = [] - async for result in udf.udf(batch): + async for result in udf(batch): results.append(result) assert len(results) == 2 diff --git a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py index d91885c99ae0f..0d6d14481eea8 100644 --- a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py +++ b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py @@ -46,21 +46,23 @@ async def test_prepare_image_udf_basic(mock_image_processor, mock_image): udf = PrepareImageUDF(data_column="__data") # Test batch with one message containing an image URL - batch = [ - { - "messages": [ - { - "content": [ - {"type": "image", "image": "http://example.com/image.jpg"} - ] - } - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": "http://example.com/image.jpg"} + ] + } + ] + } + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 assert "image" in results[0] @@ -74,22 +76,24 @@ async def test_prepare_image_udf_multiple_images(mock_image_processor, mock_imag udf = PrepareImageUDF(data_column="__data") # Test batch with multiple images in one message - batch = [ - { - "messages": [ - { - "content": [ - {"type": "image", "image": "http://example.com/image1.jpg"}, - {"type": "image", "image": "http://example.com/image2.jpg"}, - ] - } - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": "http://example.com/image1.jpg"}, + {"type": "image", "image": "http://example.com/image2.jpg"}, + ] + } + ] + } + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 assert len(results[0]["image"]) == 2 @@ -101,14 +105,14 @@ async def test_prepare_image_udf_no_images(mock_image_processor): udf = PrepareImageUDF(data_column="__data") # Test batch with no images - batch = [{"messages": [{"content": "Hello, world!"}]}] + batch = {"__data": [{"messages": [{"content": "Hello, world!"}]}]} results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 - assert results[0] == {} + assert results[0] == {"messages": [{"content": "Hello, world!"}]} @pytest.mark.asyncio @@ -143,14 +147,16 @@ async def test_prepare_image_udf_invalid_image_type(mock_image_processor): udf = PrepareImageUDF(data_column="__data") # Test batch with invalid image type - batch = [ - { - "messages": [ - {"content": [{"type": "image", "image": 123}]} # Invalid image type - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + {"content": [{"type": "image", "image": 123}]} # Invalid image type + ] + } + ] + } with pytest.raises(ValueError, match="Cannot handle image type"): - async for _ in udf.udf(batch): + async for _ in udf(batch): pass diff --git a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py index 8e940e10ebc0e..b3ccee1cdc945 100644 --- a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py +++ b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py @@ -26,11 +26,11 @@ async def test_tokenize_udf_basic(mock_tokenizer_setup): ] udf = TokenizeUDF(data_column="__data", model="test-model") - batch = [{"prompt": "Hello"}, {"prompt": "World"}] + batch = {"__data": [{"prompt": "Hello"}, {"prompt": "World"}]} results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 2 assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results) @@ -46,14 +46,16 @@ async def test_detokenize_udf_basic(mock_tokenizer_setup): mock_tokenizer.batch_decode.return_value = ["Hello", "World"] udf = DetokenizeUDF(data_column="__data", model="test-model") - batch = [ - {"generated_tokens": [1, 2, 3]}, - {"generated_tokens": [4, 5, 6]}, - ] + batch = { + "__data": [ + {"generated_tokens": [1, 2, 3]}, + {"generated_tokens": [4, 5, 6]}, + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 2 assert results[0]["generated_text"] == "Hello"