From 896ff59560fb29e727dba03a14659df8cce2b486 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 31 Jan 2025 16:43:36 -0800 Subject: [PATCH 1/6] ooo Signed-off-by: Cody Yu --- python/ray/llm/_internal/batch/stages/base.py | 41 ++++++++++++++++--- .../batch/stages/chat_template_stage.py | 16 +++++--- .../batch/stages/http_request_stage.py | 15 +++++-- .../_internal/batch/stages/tokenize_stage.py | 10 ++++- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index 4248a88c7b10..8e172541e560 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -73,6 +73,10 @@ class StatefulStageUDF: column for the next stage. """ + # The internal column name for the index of the row in the batch. + # This is used to align the output of the UDF with the input batch. + idx_in_batch_column: str = "__idx_in_batch" + def __init__(self, data_column: str): self.data_column = data_column @@ -143,6 +147,10 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] inputs = inputs.tolist() self.validate_inputs(inputs) + # Assign the index of the row in the batch to the idx_in_batch_column. + for idx, row in enumerate(inputs): + row[self.idx_in_batch_column] = idx + # Always stream the outputs one by one to better overlapping # batches. For example, when the output batch size is 64, Ray Data # will collect 64 outputs, and 1) send the batch of 64 to the next stage, @@ -151,12 +159,28 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] # for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand, # if we stream outputs one-by-one, Ray Data can form a batch of 64 before # the second batch is done. 
- idx = 0 + is_outputed = [False] * len(inputs) async for output in self.udf(inputs): + if self.idx_in_batch_column not in output: + raise ValueError( + "The output of the UDF must contain the column " + f"{self.idx_in_batch_column}." + ) + idx_in_batch = output.pop(self.idx_in_batch_column) + if is_outputed[idx_in_batch]: + raise ValueError( + f"The row {idx_in_batch} is outputed twice. " + "This is likely due to the UDF is not one-to-one." + ) + is_outputed[idx_in_batch] = True + # Add stage outputs to the data column of the row. - inputs[idx].update(output) - yield {self.data_column: [inputs[idx]]} - idx += 1 + inputs[idx_in_batch].update(output) + yield {self.data_column: [inputs[idx_in_batch]]} + + if not all(is_outputed): + missing_rows = [i for i, o in enumerate(is_outputed) if not o] + raise ValueError(f"The rows {missing_rows} are not outputed.") def validate_inputs(self, inputs: List[Dict[str, Any]]): """Validate the inputs to make sure the required keys are present. @@ -167,11 +191,18 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]): Raises: ValueError: If the required keys are not found. """ + input_keys = set(inputs[0].keys()) + + if self.idx_in_batch_column in input_keys: + raise ValueError( + f"The input column {self.idx_in_batch_column} is reserved " + "for internal use." 
+ ) + expected_input_keys = self.expected_input_keys if not expected_input_keys: return - input_keys = set(inputs[0].keys()) missing_required = set(expected_input_keys) - input_keys if missing_required: raise ValueError( diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py index f41ee1aaec15..1995d45b42c2 100644 --- a/python/ray/llm/_internal/batch/stages/chat_template_stage.py +++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py @@ -37,12 +37,18 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] Yields: A generator of rows with the chat template applied. """ - for prompt in self.tokenizer.apply_chat_template( - [row["messages"].tolist() for row in batch], - tokenize=False, - add_generation_prompt=True, + for row, prompt in zip( + batch, + self.tokenizer.apply_chat_template( + [row["messages"].tolist() for row in batch], + tokenize=False, + add_generation_prompt=True, + ), ): - yield {"prompt": prompt} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "prompt": prompt, + } @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/_internal/batch/stages/http_request_stage.py b/python/ray/llm/_internal/batch/stages/http_request_stage.py index 9b765ec2b8d0..20b6083efda3 100644 --- a/python/ray/llm/_internal/batch/stages/http_request_stage.py +++ b/python/ray/llm/_internal/batch/stages/http_request_stage.py @@ -74,12 +74,21 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] headers=headers, json=json_body, ) - pending_requests.append(request) + pending_requests.append((row[self.idx_in_batch_column], request)) # Now receive all responses - for request in pending_requests: + for idx_in_batch, request in pending_requests: async with await request as response: - yield await response.json() + resp_json = await response.json() + if self.idx_in_batch_column in resp_json: 
+ raise ValueError( + "The response of the HTTP request must not contain " + f"the column {self.idx_in_batch_column}." + ) + yield { + self.idx_in_batch_column: idx_in_batch, + **resp_json, + } class HttpRequestStage(StatefulStage): diff --git a/python/ray/llm/_internal/batch/stages/tokenize_stage.py b/python/ray/llm/_internal/batch/stages/tokenize_stage.py index 68e7bc53e3c0..b686e81ac75c 100644 --- a/python/ray/llm/_internal/batch/stages/tokenize_stage.py +++ b/python/ray/llm/_internal/batch/stages/tokenize_stage.py @@ -41,7 +41,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] batch, self.tokenizer([row["prompt"] for row in batch])["input_ids"], ): - yield {"tokenized_prompt": prompt_token_ids} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "tokenized_prompt": prompt_token_ids, + } @property def expected_input_keys(self) -> List[str]: @@ -96,7 +99,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] skip_special_tokens=True, ), ): - yield {"generated_text": generated_text} + yield { + self.idx_in_batch_column: row[self.idx_in_batch_column], + "generated_text": generated_text, + } @property def expected_input_keys(self) -> List[str]: From 1b9897f9aee3c32c40829eb9239e5d22ef68506e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 31 Jan 2025 17:10:54 -0800 Subject: [PATCH 2/6] done Signed-off-by: Cody Yu --- .../ray/llm/tests/batch/stages/test_base.py | 46 ++++++++++--- .../batch/stages/test_chat_template_stage.py | 64 +++++++++---------- .../batch/stages/test_http_request_stage.py | 12 ++-- .../tests/batch/stages/test_tokenize_stage.py | 20 +++--- 4 files changed, 84 insertions(+), 58 deletions(-) diff --git a/python/ray/llm/tests/batch/stages/test_base.py b/python/ray/llm/tests/batch/stages/test_base.py index c99754a23b59..0e98a33b6425 100644 --- a/python/ray/llm/tests/batch/stages/test_base.py +++ b/python/ray/llm/tests/batch/stages/test_base.py @@ -40,11 +40,26 @@ def 
to_string(x: dict) -> dict: class TestStatefulStageUDF: class SimpleUDF(StatefulStageUDF): + def __init__( + self, + data_column: str, + udf_output_missing_idx_in_batch_column: bool = False, + ): + super().__init__(data_column) + self.udf_output_missing_idx_in_batch_column = ( + udf_output_missing_idx_in_batch_column + ) + async def udf( self, rows: list[Dict[str, Any]] ) -> AsyncIterator[Dict[str, Any]]: - for row in rows: - yield {"processed": row["value"] * 2} + + # Intentionally output in a reversed order to test OOO. + for row in rows[::-1]: + ret = {"processed": row["value"] * 2} + if not self.udf_output_missing_idx_in_batch_column: + ret[self.idx_in_batch_column] = row[self.idx_in_batch_column] + yield ret @property def expected_input_keys(self) -> List[str]: @@ -55,7 +70,7 @@ async def test_basic_processing(self): udf = self.SimpleUDF(data_column="__data") batch = { - "__data": [{"value": 1, "extra": "a"}, {"value": 2, "extra": "b"}], + "__data": [{"value": 1, "extra": 10}, {"value": 2, "extra": 20}], } results = [] @@ -63,12 +78,12 @@ async def test_basic_processing(self): results.append(result) assert len(results) == 2 - assert results[0] == { - "__data": [{"processed": 2, "value": 1, "extra": "a"}], - } - assert results[1] == { - "__data": [{"processed": 4, "value": 2, "extra": "b"}], - } + for result in results: + data = result["__data"][0] + val = data["value"] + assert data["processed"] == val * 2 + assert data["extra"] == 10 * val + assert data["value"] == val @pytest.mark.asyncio async def test_missing_data_column(self): @@ -90,6 +105,19 @@ async def test_missing_required_key(self): async for _ in udf(batch): pass + @pytest.mark.asyncio + async def test_missing_idx_in_batch_column(self): + udf = self.SimpleUDF( + data_column="__data", + udf_output_missing_idx_in_batch_column=True, + ) + + batch = {"__data": [{"value": 1}]} + + with pytest.raises(ValueError): + async for _ in udf(batch): + pass + def test_stateful_stage(): udf = 
TestStatefulStageUDF.SimpleUDF(data_column="__data") diff --git a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py index 4d822d6fcd43..f5a05a0fd93b 100644 --- a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py +++ b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py @@ -24,25 +24,23 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup): model="test-model", ) - batch = [ - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "Hello AI"}] - ) - } - ] + batch = { + "__data": [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + } + ] + } results = [] - async for result in udf.udf(batch): + async for result in udf(batch): results.append(result) assert len(results) == 1 - assert results[0] == {"prompt": "Hello AI"} - mock_tokenizer.apply_chat_template.assert_called_once_with( - [batch[0]["messages"].tolist()], - tokenize=False, - add_generation_prompt=True, - ) + assert results[0]["__data"][0]["prompt"] == "Hello AI" + mock_tokenizer.apply_chat_template.assert_called_once() @pytest.mark.asyncio @@ -58,31 +56,29 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup): model="test-model", ) - batch = [ - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "Hello AI"}] - ) - }, - { - "messages": MagicMock( - tolist=lambda: [{"role": "user", "content": "How are you?"}] - ) - }, - ] + batch = { + "__data": [ + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "Hello AI"}] + ) + }, + { + "messages": MagicMock( + tolist=lambda: [{"role": "user", "content": "How are you?"}], + ) + }, + ] + } results = [] - async for result in udf.udf(batch): + async for result in udf(batch): results.append(result) assert len(results) == 2 - assert results[0] == {"prompt": "Hello AI"} - assert results[1] == {"prompt": "How are you?"} - 
mock_tokenizer.apply_chat_template.assert_called_once_with( - [msg["messages"].tolist() for msg in batch], - tokenize=False, - add_generation_prompt=True, - ) + assert results[0]["__data"][0]["prompt"] == "Hello AI" + assert results[1]["__data"][0]["prompt"] == "How are you?" + mock_tokenizer.apply_chat_template.assert_called_once() def test_chat_template_udf_expected_input_keys(mock_tokenizer_setup): diff --git a/python/ray/llm/tests/batch/stages/test_http_request_stage.py b/python/ray/llm/tests/batch/stages/test_http_request_stage.py index 636979f893ac..2477399bb016 100644 --- a/python/ray/llm/tests/batch/stages/test_http_request_stage.py +++ b/python/ray/llm/tests/batch/stages/test_http_request_stage.py @@ -32,7 +32,7 @@ async def test_http_request_udf_basic(): qps=None, ) - batch = [{"text": "hello", "metadata": "test"}] + batch = {"__data": [{"text": "hello", "metadata": "test"}]} with patch("aiohttp.ClientSession") as mock_session_cls: session = AsyncMock() @@ -41,8 +41,8 @@ async def test_http_request_udf_basic(): ) mock_session_cls.return_value.__aenter__.return_value = session - async for result in udf.udf(batch): - assert result == {"response": "test"} + async for result in udf(batch): + assert result["__data"][0]["response"] == "test" session.post.assert_called_once_with( "http://test.com/api", @@ -50,7 +50,7 @@ async def test_http_request_udf_basic(): "Content-Type": "application/json", "Authorization": "Bearer 1234567890", }, - json={"text": "hello", "metadata": "test"}, + json={"text": "hello", "metadata": "test", "__idx_in_batch": 0}, ) @@ -62,7 +62,7 @@ async def test_http_request_udf_with_qps(): qps=2, ) - batch = [{"text": "hello1"}, {"text": "hello2"}] + batch = {"__data": [{"text": "hello1"}, {"text": "hello2"}]} with patch("aiohttp.ClientSession") as mock_session_cls, patch( "time.time" @@ -80,7 +80,7 @@ async def test_http_request_udf_with_qps(): mock_time.side_effect = [0, 0.1, 0.2] results = [] - async for result in udf.udf(batch): + async 
for result in udf(batch): results.append(result) assert len(results) == 2 diff --git a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py index 8e940e10ebc0..b3ccee1cdc94 100644 --- a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py +++ b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py @@ -26,11 +26,11 @@ async def test_tokenize_udf_basic(mock_tokenizer_setup): ] udf = TokenizeUDF(data_column="__data", model="test-model") - batch = [{"prompt": "Hello"}, {"prompt": "World"}] + batch = {"__data": [{"prompt": "Hello"}, {"prompt": "World"}]} results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 2 assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results) @@ -46,14 +46,16 @@ async def test_detokenize_udf_basic(mock_tokenizer_setup): mock_tokenizer.batch_decode.return_value = ["Hello", "World"] udf = DetokenizeUDF(data_column="__data", model="test-model") - batch = [ - {"generated_tokens": [1, 2, 3]}, - {"generated_tokens": [4, 5, 6]}, - ] + batch = { + "__data": [ + {"generated_tokens": [1, 2, 3]}, + {"generated_tokens": [4, 5, 6]}, + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 2 assert results[0]["generated_text"] == "Hello" From 198b83fe300921ad5355c22643bf074ee2416ec6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 31 Jan 2025 17:37:00 -0800 Subject: [PATCH 3/6] fix Signed-off-by: Cody Yu --- python/ray/llm/_internal/batch/stages/base.py | 1 + .../batch/stages/prepare_image_stage.py | 28 +++++-- .../batch/stages/test_prepare_image_stage.py | 84 ++++++++++--------- 3 files changed, 66 insertions(+), 47 deletions(-) diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index 
8e172541e560..666156af82d1 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -175,6 +175,7 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] is_outputed[idx_in_batch] = True # Add stage outputs to the data column of the row. + inputs[idx_in_batch].pop(self.idx_in_batch_column) inputs[idx_in_batch].update(output) yield {self.data_column: [inputs[idx_in_batch]]} diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py index 88641a8343ab..4c933cd96b83 100644 --- a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py +++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py @@ -342,18 +342,30 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] flat_all_image_info = [img for imgs in all_image_info for img in imgs] flat_all_images = await self.image_processor.process(flat_all_image_info) - idx = 0 + # TODO: We now use asyncio.gather to process all images in this batch, + # so the outputs here must be in order. However, it is more efficient + # to support out-of-order outputs so that we won't be blocked by slow + # downloaded images. 
+ img_start_idx = 0 + idx_in_batch = 0 for image_info_per_req in all_image_info: num_images_in_req = len(image_info_per_req) + ret = {self.idx_in_batch_column: idx_in_batch} if num_images_in_req == 0: - yield {} + yield ret else: - images = flat_all_images[idx : idx + num_images_in_req] - yield { - "image": images, - "image_sizes": [(img.width, img.height) for img in images], - } - idx += num_images_in_req + images = flat_all_images[ + img_start_idx : img_start_idx + num_images_in_req + ] + ret.update( + { + "image": images, + "image_sizes": [(img.width, img.height) for img in images], + } + ) + yield ret + img_start_idx += num_images_in_req + idx_in_batch += 1 @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py index d91885c99ae0..0d6d14481eea 100644 --- a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py +++ b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py @@ -46,21 +46,23 @@ async def test_prepare_image_udf_basic(mock_image_processor, mock_image): udf = PrepareImageUDF(data_column="__data") # Test batch with one message containing an image URL - batch = [ - { - "messages": [ - { - "content": [ - {"type": "image", "image": "http://example.com/image.jpg"} - ] - } - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": "http://example.com/image.jpg"} + ] + } + ] + } + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 assert "image" in results[0] @@ -74,22 +76,24 @@ async def test_prepare_image_udf_multiple_images(mock_image_processor, mock_imag udf = PrepareImageUDF(data_column="__data") # Test batch with multiple images in one message - batch = [ - { - "messages": [ - { - "content": [ - {"type": "image", "image": 
"http://example.com/image1.jpg"}, - {"type": "image", "image": "http://example.com/image2.jpg"}, - ] - } - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + { + "content": [ + {"type": "image", "image": "http://example.com/image1.jpg"}, + {"type": "image", "image": "http://example.com/image2.jpg"}, + ] + } + ] + } + ] + } results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 assert len(results[0]["image"]) == 2 @@ -101,14 +105,14 @@ async def test_prepare_image_udf_no_images(mock_image_processor): udf = PrepareImageUDF(data_column="__data") # Test batch with no images - batch = [{"messages": [{"content": "Hello, world!"}]}] + batch = {"__data": [{"messages": [{"content": "Hello, world!"}]}]} results = [] - async for result in udf.udf(batch): - results.append(result) + async for result in udf(batch): + results.append(result["__data"][0]) assert len(results) == 1 - assert results[0] == {} + assert results[0] == {"messages": [{"content": "Hello, world!"}]} @pytest.mark.asyncio @@ -143,14 +147,16 @@ async def test_prepare_image_udf_invalid_image_type(mock_image_processor): udf = PrepareImageUDF(data_column="__data") # Test batch with invalid image type - batch = [ - { - "messages": [ - {"content": [{"type": "image", "image": 123}]} # Invalid image type - ] - } - ] + batch = { + "__data": [ + { + "messages": [ + {"content": [{"type": "image", "image": 123}]} # Invalid image type + ] + } + ] + } with pytest.raises(ValueError, match="Cannot handle image type"): - async for _ in udf.udf(batch): + async for _ in udf(batch): pass From ddd1d759b7fa68daa9cb0cdca6832b29431e7ab6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 3 Feb 2025 15:12:29 -0800 Subject: [PATCH 4/6] fix Signed-off-by: Cody Yu --- python/ray/llm/tests/batch/processor/test_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/ray/llm/tests/batch/processor/test_base.py 
b/python/ray/llm/tests/batch/processor/test_base.py index b186117ac551..f9db4156934e 100644 --- a/python/ray/llm/tests/batch/processor/test_base.py +++ b/python/ray/llm/tests/batch/processor/test_base.py @@ -59,6 +59,7 @@ async def udf( yield { # Use the same name to chain multiple dummy stages. "val": answer, + self.idx_in_batch_column: row[self.idx_in_batch_column], } @property From b8880a2c0137679b14c650eb96352158c3a74e4d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 4 Feb 2025 15:13:57 -0800 Subject: [PATCH 5/6] comments Signed-off-by: Cody Yu --- python/ray/llm/_internal/batch/stages/base.py | 43 ++++++++++--------- .../batch/stages/chat_template_stage.py | 16 +++---- .../batch/stages/prepare_image_stage.py | 8 ++-- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index 666156af82d1..f7a55266adf2 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -159,7 +159,7 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] # for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand, # if we stream outputs one-by-one, Ray Data can form a batch of 64 before # the second batch is done. - is_outputed = [False] * len(inputs) + is_outputed = set(range(len(inputs))) async for output in self.udf(inputs): if self.idx_in_batch_column not in output: raise ValueError( @@ -167,21 +167,20 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] f"{self.idx_in_batch_column}." ) idx_in_batch = output.pop(self.idx_in_batch_column) - if is_outputed[idx_in_batch]: + if idx_in_batch not in is_outputed: raise ValueError( f"The row {idx_in_batch} is outputed twice. " "This is likely due to the UDF is not one-to-one." ) - is_outputed[idx_in_batch] = True + is_outputed.remove(idx_in_batch) # Add stage outputs to the data column of the row. 
inputs[idx_in_batch].pop(self.idx_in_batch_column) inputs[idx_in_batch].update(output) yield {self.data_column: [inputs[idx_in_batch]]} - if not all(is_outputed): - missing_rows = [i for i, o in enumerate(is_outputed) if not o] - raise ValueError(f"The rows {missing_rows} are not outputed.") + if is_outputed: + raise ValueError(f"The rows {is_outputed} are not outputed.") def validate_inputs(self, inputs: List[Dict[str, Any]]): """Validate the inputs to make sure the required keys are present. @@ -192,24 +191,26 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]): Raises: ValueError: If the required keys are not found. """ - input_keys = set(inputs[0].keys()) + expected_input_keys = set(self.expected_input_keys) - if self.idx_in_batch_column in input_keys: - raise ValueError( - f"The input column {self.idx_in_batch_column} is reserved " - "for internal use." - ) + for inp in inputs: + input_keys = set(inp.keys()) - expected_input_keys = self.expected_input_keys - if not expected_input_keys: - return + if self.idx_in_batch_column in input_keys: + raise ValueError( + f"The input column {self.idx_in_batch_column} is reserved " + "for internal use." + ) - missing_required = set(expected_input_keys) - input_keys - if missing_required: - raise ValueError( - f"Required input keys {missing_required} not found at the input of " - f"{self.__class__.__name__}. Input keys: {input_keys}" - ) + if not expected_input_keys: + continue + + missing_required = expected_input_keys - input_keys + if missing_required: + raise ValueError( + f"Required input keys {missing_required} not found at the input of " + f"{self.__class__.__name__}. 
Input keys: {input_keys}" + ) @property def expected_input_keys(self) -> List[str]: diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py index 1995d45b42c2..3069d4faed27 100644 --- a/python/ray/llm/_internal/batch/stages/chat_template_stage.py +++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py @@ -37,14 +37,14 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] Yields: A generator of rows with the chat template applied. """ - for row, prompt in zip( - batch, - self.tokenizer.apply_chat_template( - [row["messages"].tolist() for row in batch], - tokenize=False, - add_generation_prompt=True, - ), - ): + prompts = self.tokenizer.apply_chat_template( + [row["messages"].tolist() for row in batch], + tokenize=False, + add_generation_prompt=True, + ) + assert len(batch) == len(prompts) + + for row, prompt in zip(batch, prompts): yield { self.idx_in_batch_column: row[self.idx_in_batch_column], "prompt": prompt, diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py index 4c933cd96b83..d7e885fe958f 100644 --- a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py +++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py @@ -351,9 +351,8 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] for image_info_per_req in all_image_info: num_images_in_req = len(image_info_per_req) ret = {self.idx_in_batch_column: idx_in_batch} - if num_images_in_req == 0: - yield ret - else: + idx_in_batch += 1 + if num_images_in_req > 0: images = flat_all_images[ img_start_idx : img_start_idx + num_images_in_req ] @@ -363,9 +362,8 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] "image_sizes": [(img.width, img.height) for img in images], } ) - yield ret img_start_idx += num_images_in_req - idx_in_batch += 1 
+ yield ret @property def expected_input_keys(self) -> List[str]: From 3dd3442dcd9d7396d495b05112856fa2961a6ea4 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 5 Feb 2025 10:54:34 -0800 Subject: [PATCH 6/6] comment Signed-off-by: Cody Yu --- python/ray/llm/_internal/batch/stages/base.py | 30 +++++++++++-------- .../batch/stages/chat_template_stage.py | 2 +- .../batch/stages/http_request_stage.py | 8 ++--- .../batch/stages/prepare_image_stage.py | 2 +- .../_internal/batch/stages/tokenize_stage.py | 4 +-- .../llm/tests/batch/processor/test_base.py | 2 +- .../ray/llm/tests/batch/stages/test_base.py | 2 +- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py index f7a55266adf2..cee80cab50b9 100644 --- a/python/ray/llm/_internal/batch/stages/base.py +++ b/python/ray/llm/_internal/batch/stages/base.py @@ -75,7 +75,7 @@ class StatefulStageUDF: # The internal column name for the index of the row in the batch. # This is used to align the output of the UDF with the input batch. - idx_in_batch_column: str = "__idx_in_batch" + IDX_IN_BATCH_COLUMN: str = "__idx_in_batch" def __init__(self, data_column: str): self.data_column = data_column @@ -148,8 +148,12 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] self.validate_inputs(inputs) # Assign the index of the row in the batch to the idx_in_batch_column. + # This is because the UDF output may be out-of-order (if asyncio.as_completed + # is used internally for example), and we need to carry over unused input + # columns to the next stage. Thus, we use the row index in batch to match + # the output of the UDF with the input. for idx, row in enumerate(inputs): - row[self.idx_in_batch_column] = idx + row[self.IDX_IN_BATCH_COLUMN] = idx # Always stream the outputs one by one to better overlapping # batches.
For example, when the output batch size is 64, Ray Data @@ -159,28 +163,28 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]] # for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand, # if we stream outputs one-by-one, Ray Data can form a batch of 64 before # the second batch is done. - is_outputed = set(range(len(inputs))) + not_outputed_rows = set(range(len(inputs))) async for output in self.udf(inputs): - if self.idx_in_batch_column not in output: + if self.IDX_IN_BATCH_COLUMN not in output: raise ValueError( "The output of the UDF must contain the column " - f"{self.idx_in_batch_column}." + f"{self.IDX_IN_BATCH_COLUMN}." ) - idx_in_batch = output.pop(self.idx_in_batch_column) - if idx_in_batch not in is_outputed: + idx_in_batch = output.pop(self.IDX_IN_BATCH_COLUMN) + if idx_in_batch not in not_outputed_rows: raise ValueError( f"The row {idx_in_batch} is outputed twice. " "This is likely due to the UDF is not one-to-one." ) - is_outputed.remove(idx_in_batch) + not_outputed_rows.remove(idx_in_batch) # Add stage outputs to the data column of the row. - inputs[idx_in_batch].pop(self.idx_in_batch_column) + inputs[idx_in_batch].pop(self.IDX_IN_BATCH_COLUMN) inputs[idx_in_batch].update(output) yield {self.data_column: [inputs[idx_in_batch]]} - if is_outputed: - raise ValueError(f"The rows {is_outputed} are not outputed.") + if not_outputed_rows: + raise ValueError(f"The rows {not_outputed_rows} are not outputed.") def validate_inputs(self, inputs: List[Dict[str, Any]]): """Validate the inputs to make sure the required keys are present. @@ -196,9 +200,9 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]): for inp in inputs: input_keys = set(inp.keys()) - if self.idx_in_batch_column in input_keys: + if self.IDX_IN_BATCH_COLUMN in input_keys: raise ValueError( - f"The input column {self.idx_in_batch_column} is reserved " + f"The input column {self.IDX_IN_BATCH_COLUMN} is reserved " "for internal use." 
) diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py index 3069d4faed27..bb760dc8f69c 100644 --- a/python/ray/llm/_internal/batch/stages/chat_template_stage.py +++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py @@ -46,7 +46,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] for row, prompt in zip(batch, prompts): yield { - self.idx_in_batch_column: row[self.idx_in_batch_column], + self.IDX_IN_BATCH_COLUMN: row[self.IDX_IN_BATCH_COLUMN], "prompt": prompt, } diff --git a/python/ray/llm/_internal/batch/stages/http_request_stage.py b/python/ray/llm/_internal/batch/stages/http_request_stage.py index 20b6083efda3..2b406cd08526 100644 --- a/python/ray/llm/_internal/batch/stages/http_request_stage.py +++ b/python/ray/llm/_internal/batch/stages/http_request_stage.py @@ -74,19 +74,19 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] headers=headers, json=json_body, ) - pending_requests.append((row[self.idx_in_batch_column], request)) + pending_requests.append((row[self.IDX_IN_BATCH_COLUMN], request)) # Now receive all responses for idx_in_batch, request in pending_requests: async with await request as response: resp_json = await response.json() - if self.idx_in_batch_column in resp_json: + if self.IDX_IN_BATCH_COLUMN in resp_json: raise ValueError( "The response of the HTTP request must not contain " - f"the column {self.idx_in_batch_column}." + f"the column {self.IDX_IN_BATCH_COLUMN}." 
) yield { - self.idx_in_batch_column: idx_in_batch, + self.IDX_IN_BATCH_COLUMN: idx_in_batch, **resp_json, } diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py index d7e885fe958f..c8ca29c6f5d4 100644 --- a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py +++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py @@ -350,7 +350,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] idx_in_batch = 0 for image_info_per_req in all_image_info: num_images_in_req = len(image_info_per_req) - ret = {self.idx_in_batch_column: idx_in_batch} + ret = {self.IDX_IN_BATCH_COLUMN: idx_in_batch} idx_in_batch += 1 if num_images_in_req > 0: images = flat_all_images[ diff --git a/python/ray/llm/_internal/batch/stages/tokenize_stage.py b/python/ray/llm/_internal/batch/stages/tokenize_stage.py index b686e81ac75c..64c9a51220b0 100644 --- a/python/ray/llm/_internal/batch/stages/tokenize_stage.py +++ b/python/ray/llm/_internal/batch/stages/tokenize_stage.py @@ -42,7 +42,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] self.tokenizer([row["prompt"] for row in batch])["input_ids"], ): yield { - self.idx_in_batch_column: row[self.idx_in_batch_column], + self.IDX_IN_BATCH_COLUMN: row[self.IDX_IN_BATCH_COLUMN], "tokenized_prompt": prompt_token_ids, } @@ -100,7 +100,7 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any] ), ): yield { - self.idx_in_batch_column: row[self.idx_in_batch_column], + self.IDX_IN_BATCH_COLUMN: row[self.IDX_IN_BATCH_COLUMN], "generated_text": generated_text, } diff --git a/python/ray/llm/tests/batch/processor/test_base.py b/python/ray/llm/tests/batch/processor/test_base.py index f9db4156934e..b684f0891f61 100644 --- a/python/ray/llm/tests/batch/processor/test_base.py +++ b/python/ray/llm/tests/batch/processor/test_base.py @@ -59,7 +59,7 @@ async def udf( yield { # Use 
the same name to chain multiple dummy stages. "val": answer, - self.idx_in_batch_column: row[self.idx_in_batch_column], + self.IDX_IN_BATCH_COLUMN: row[self.IDX_IN_BATCH_COLUMN], } @property diff --git a/python/ray/llm/tests/batch/stages/test_base.py b/python/ray/llm/tests/batch/stages/test_base.py index 0e98a33b6425..cc1be67874c9 100644 --- a/python/ray/llm/tests/batch/stages/test_base.py +++ b/python/ray/llm/tests/batch/stages/test_base.py @@ -58,7 +58,7 @@ async def udf( for row in rows[::-1]: ret = {"processed": row["value"] * 2} if not self.udf_output_missing_idx_in_batch_column: - ret[self.idx_in_batch_column] = row[self.idx_in_batch_column] + ret[self.IDX_IN_BATCH_COLUMN] = row[self.IDX_IN_BATCH_COLUMN] yield ret @property