diff --git a/python/ray/llm/_internal/batch/stages/base.py b/python/ray/llm/_internal/batch/stages/base.py
index 4248a88c7b102..666156af82d10 100644
--- a/python/ray/llm/_internal/batch/stages/base.py
+++ b/python/ray/llm/_internal/batch/stages/base.py
@@ -73,6 +73,10 @@ class StatefulStageUDF:
column for the next stage.
"""
+ # The internal column name for the index of the row in the batch.
+ # This is used to align the output of the UDF with the input batch.
+ idx_in_batch_column: str = "__idx_in_batch"
+
def __init__(self, data_column: str):
self.data_column = data_column
@@ -143,6 +147,10 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
inputs = inputs.tolist()
self.validate_inputs(inputs)
+ # Assign the index of the row in the batch to the idx_in_batch_column.
+ for idx, row in enumerate(inputs):
+ row[self.idx_in_batch_column] = idx
+
# Always stream the outputs one by one to better overlapping
# batches. For example, when the output batch size is 64, Ray Data
# will collect 64 outputs, and 1) send the batch of 64 to the next stage,
@@ -151,12 +159,29 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
# for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand,
# if we stream outputs one-by-one, Ray Data can form a batch of 64 before
# the second batch is done.
- idx = 0
+ is_outputed = [False] * len(inputs)
async for output in self.udf(inputs):
+ if self.idx_in_batch_column not in output:
+ raise ValueError(
+ "The output of the UDF must contain the column "
+ f"{self.idx_in_batch_column}."
+ )
+ idx_in_batch = output.pop(self.idx_in_batch_column)
+ if is_outputed[idx_in_batch]:
+ raise ValueError(
+                    f"The row {idx_in_batch} is output twice. "
+                    "This is likely because the UDF is not one-to-one."
+ )
+ is_outputed[idx_in_batch] = True
+
# Add stage outputs to the data column of the row.
- inputs[idx].update(output)
- yield {self.data_column: [inputs[idx]]}
- idx += 1
+ inputs[idx_in_batch].pop(self.idx_in_batch_column)
+ inputs[idx_in_batch].update(output)
+ yield {self.data_column: [inputs[idx_in_batch]]}
+
+ if not all(is_outputed):
+ missing_rows = [i for i, o in enumerate(is_outputed) if not o]
+            raise ValueError(f"The rows {missing_rows} are not output.")
def validate_inputs(self, inputs: List[Dict[str, Any]]):
"""Validate the inputs to make sure the required keys are present.
@@ -167,11 +192,18 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]):
Raises:
ValueError: If the required keys are not found.
"""
+ input_keys = set(inputs[0].keys())
+
+ if self.idx_in_batch_column in input_keys:
+ raise ValueError(
+ f"The input column {self.idx_in_batch_column} is reserved "
+ "for internal use."
+ )
+
expected_input_keys = self.expected_input_keys
if not expected_input_keys:
return
- input_keys = set(inputs[0].keys())
missing_required = set(expected_input_keys) - input_keys
if missing_required:
raise ValueError(
diff --git a/python/ray/llm/_internal/batch/stages/chat_template_stage.py b/python/ray/llm/_internal/batch/stages/chat_template_stage.py
index f41ee1aaec159..1995d45b42c23 100644
--- a/python/ray/llm/_internal/batch/stages/chat_template_stage.py
+++ b/python/ray/llm/_internal/batch/stages/chat_template_stage.py
@@ -37,12 +37,18 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
Yields:
A generator of rows with the chat template applied.
"""
- for prompt in self.tokenizer.apply_chat_template(
- [row["messages"].tolist() for row in batch],
- tokenize=False,
- add_generation_prompt=True,
+ for row, prompt in zip(
+ batch,
+ self.tokenizer.apply_chat_template(
+ [row["messages"].tolist() for row in batch],
+ tokenize=False,
+ add_generation_prompt=True,
+ ),
):
- yield {"prompt": prompt}
+ yield {
+ self.idx_in_batch_column: row[self.idx_in_batch_column],
+ "prompt": prompt,
+ }
@property
def expected_input_keys(self) -> List[str]:
diff --git a/python/ray/llm/_internal/batch/stages/http_request_stage.py b/python/ray/llm/_internal/batch/stages/http_request_stage.py
index 9b765ec2b8d02..20b6083efda36 100644
--- a/python/ray/llm/_internal/batch/stages/http_request_stage.py
+++ b/python/ray/llm/_internal/batch/stages/http_request_stage.py
@@ -74,12 +74,21 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
headers=headers,
json=json_body,
)
- pending_requests.append(request)
+ pending_requests.append((row[self.idx_in_batch_column], request))
# Now receive all responses
- for request in pending_requests:
+ for idx_in_batch, request in pending_requests:
async with await request as response:
- yield await response.json()
+ resp_json = await response.json()
+ if self.idx_in_batch_column in resp_json:
+ raise ValueError(
+ "The response of the HTTP request must not contain "
+ f"the column {self.idx_in_batch_column}."
+ )
+ yield {
+ self.idx_in_batch_column: idx_in_batch,
+ **resp_json,
+ }
class HttpRequestStage(StatefulStage):
diff --git a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py
index 88641a8343ab7..4c933cd96b838 100644
--- a/python/ray/llm/_internal/batch/stages/prepare_image_stage.py
+++ b/python/ray/llm/_internal/batch/stages/prepare_image_stage.py
@@ -342,18 +342,30 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
flat_all_image_info = [img for imgs in all_image_info for img in imgs]
flat_all_images = await self.image_processor.process(flat_all_image_info)
- idx = 0
+ # TODO: We now use asyncio.gather to process all images in this batch,
+ # so the outputs here must be in order. However, it is more efficient
+        # to support out-of-order outputs so that we won't be blocked by
+        # slow image downloads.
+ img_start_idx = 0
+ idx_in_batch = 0
for image_info_per_req in all_image_info:
num_images_in_req = len(image_info_per_req)
+ ret = {self.idx_in_batch_column: idx_in_batch}
if num_images_in_req == 0:
- yield {}
+ yield ret
else:
- images = flat_all_images[idx : idx + num_images_in_req]
- yield {
- "image": images,
- "image_sizes": [(img.width, img.height) for img in images],
- }
- idx += num_images_in_req
+ images = flat_all_images[
+ img_start_idx : img_start_idx + num_images_in_req
+ ]
+ ret.update(
+ {
+ "image": images,
+ "image_sizes": [(img.width, img.height) for img in images],
+ }
+ )
+ yield ret
+ img_start_idx += num_images_in_req
+ idx_in_batch += 1
@property
def expected_input_keys(self) -> List[str]:
diff --git a/python/ray/llm/_internal/batch/stages/tokenize_stage.py b/python/ray/llm/_internal/batch/stages/tokenize_stage.py
index 68e7bc53e3c0c..b686e81ac75cc 100644
--- a/python/ray/llm/_internal/batch/stages/tokenize_stage.py
+++ b/python/ray/llm/_internal/batch/stages/tokenize_stage.py
@@ -41,7 +41,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
batch,
self.tokenizer([row["prompt"] for row in batch])["input_ids"],
):
- yield {"tokenized_prompt": prompt_token_ids}
+ yield {
+ self.idx_in_batch_column: row[self.idx_in_batch_column],
+ "tokenized_prompt": prompt_token_ids,
+ }
@property
def expected_input_keys(self) -> List[str]:
@@ -96,7 +99,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]
skip_special_tokens=True,
),
):
- yield {"generated_text": generated_text}
+ yield {
+ self.idx_in_batch_column: row[self.idx_in_batch_column],
+ "generated_text": generated_text,
+ }
@property
def expected_input_keys(self) -> List[str]:
diff --git a/python/ray/llm/tests/batch/stages/test_base.py b/python/ray/llm/tests/batch/stages/test_base.py
index c99754a23b598..0e98a33b64254 100644
--- a/python/ray/llm/tests/batch/stages/test_base.py
+++ b/python/ray/llm/tests/batch/stages/test_base.py
@@ -40,11 +40,26 @@ def to_string(x: dict) -> dict:
class TestStatefulStageUDF:
class SimpleUDF(StatefulStageUDF):
+ def __init__(
+ self,
+ data_column: str,
+ udf_output_missing_idx_in_batch_column: bool = False,
+ ):
+ super().__init__(data_column)
+ self.udf_output_missing_idx_in_batch_column = (
+ udf_output_missing_idx_in_batch_column
+ )
+
async def udf(
self, rows: list[Dict[str, Any]]
) -> AsyncIterator[Dict[str, Any]]:
- for row in rows:
- yield {"processed": row["value"] * 2}
+
+        # Intentionally yield rows in reversed order to test out-of-order output handling.
+ for row in rows[::-1]:
+ ret = {"processed": row["value"] * 2}
+ if not self.udf_output_missing_idx_in_batch_column:
+ ret[self.idx_in_batch_column] = row[self.idx_in_batch_column]
+ yield ret
@property
def expected_input_keys(self) -> List[str]:
@@ -55,7 +70,7 @@ async def test_basic_processing(self):
udf = self.SimpleUDF(data_column="__data")
batch = {
- "__data": [{"value": 1, "extra": "a"}, {"value": 2, "extra": "b"}],
+ "__data": [{"value": 1, "extra": 10}, {"value": 2, "extra": 20}],
}
results = []
@@ -63,12 +78,12 @@ async def test_basic_processing(self):
results.append(result)
assert len(results) == 2
- assert results[0] == {
- "__data": [{"processed": 2, "value": 1, "extra": "a"}],
- }
- assert results[1] == {
- "__data": [{"processed": 4, "value": 2, "extra": "b"}],
- }
+ for result in results:
+ data = result["__data"][0]
+ val = data["value"]
+ assert data["processed"] == val * 2
+ assert data["extra"] == 10 * val
+ assert data["value"] == val
@pytest.mark.asyncio
async def test_missing_data_column(self):
@@ -90,6 +105,19 @@ async def test_missing_required_key(self):
async for _ in udf(batch):
pass
+ @pytest.mark.asyncio
+ async def test_missing_idx_in_batch_column(self):
+ udf = self.SimpleUDF(
+ data_column="__data",
+ udf_output_missing_idx_in_batch_column=True,
+ )
+
+ batch = {"__data": [{"value": 1}]}
+
+ with pytest.raises(ValueError):
+ async for _ in udf(batch):
+ pass
+
def test_stateful_stage():
udf = TestStatefulStageUDF.SimpleUDF(data_column="__data")
diff --git a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py
index 4d822d6fcd439..f5a05a0fd93bd 100644
--- a/python/ray/llm/tests/batch/stages/test_chat_template_stage.py
+++ b/python/ray/llm/tests/batch/stages/test_chat_template_stage.py
@@ -24,25 +24,23 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup):
model="test-model",
)
- batch = [
- {
- "messages": MagicMock(
- tolist=lambda: [{"role": "user", "content": "Hello AI"}]
- )
- }
- ]
+ batch = {
+ "__data": [
+ {
+ "messages": MagicMock(
+ tolist=lambda: [{"role": "user", "content": "Hello AI"}]
+ )
+ }
+ ]
+ }
results = []
- async for result in udf.udf(batch):
+ async for result in udf(batch):
results.append(result)
assert len(results) == 1
- assert results[0] == {"prompt": "Hello AI"}
- mock_tokenizer.apply_chat_template.assert_called_once_with(
- [batch[0]["messages"].tolist()],
- tokenize=False,
- add_generation_prompt=True,
- )
+ assert results[0]["__data"][0]["prompt"] == "Hello AI"
+ mock_tokenizer.apply_chat_template.assert_called_once()
@pytest.mark.asyncio
@@ -58,31 +56,29 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup):
model="test-model",
)
- batch = [
- {
- "messages": MagicMock(
- tolist=lambda: [{"role": "user", "content": "Hello AI"}]
- )
- },
- {
- "messages": MagicMock(
- tolist=lambda: [{"role": "user", "content": "How are you?"}]
- )
- },
- ]
+ batch = {
+ "__data": [
+ {
+ "messages": MagicMock(
+ tolist=lambda: [{"role": "user", "content": "Hello AI"}]
+ )
+ },
+ {
+ "messages": MagicMock(
+ tolist=lambda: [{"role": "user", "content": "How are you?"}],
+ )
+ },
+ ]
+ }
results = []
- async for result in udf.udf(batch):
+ async for result in udf(batch):
results.append(result)
assert len(results) == 2
- assert results[0] == {"prompt": "Hello AI"}
- assert results[1] == {"prompt": "How are you?"}
- mock_tokenizer.apply_chat_template.assert_called_once_with(
- [msg["messages"].tolist() for msg in batch],
- tokenize=False,
- add_generation_prompt=True,
- )
+ assert results[0]["__data"][0]["prompt"] == "Hello AI"
+ assert results[1]["__data"][0]["prompt"] == "How are you?"
+ mock_tokenizer.apply_chat_template.assert_called_once()
def test_chat_template_udf_expected_input_keys(mock_tokenizer_setup):
diff --git a/python/ray/llm/tests/batch/stages/test_http_request_stage.py b/python/ray/llm/tests/batch/stages/test_http_request_stage.py
index 636979f893ac6..2477399bb0167 100644
--- a/python/ray/llm/tests/batch/stages/test_http_request_stage.py
+++ b/python/ray/llm/tests/batch/stages/test_http_request_stage.py
@@ -32,7 +32,7 @@ async def test_http_request_udf_basic():
qps=None,
)
- batch = [{"text": "hello", "metadata": "test"}]
+ batch = {"__data": [{"text": "hello", "metadata": "test"}]}
with patch("aiohttp.ClientSession") as mock_session_cls:
session = AsyncMock()
@@ -41,8 +41,8 @@ async def test_http_request_udf_basic():
)
mock_session_cls.return_value.__aenter__.return_value = session
- async for result in udf.udf(batch):
- assert result == {"response": "test"}
+ async for result in udf(batch):
+ assert result["__data"][0]["response"] == "test"
session.post.assert_called_once_with(
"http://test.com/api",
@@ -50,7 +50,7 @@ async def test_http_request_udf_basic():
"Content-Type": "application/json",
"Authorization": "Bearer 1234567890",
},
- json={"text": "hello", "metadata": "test"},
+ json={"text": "hello", "metadata": "test", "__idx_in_batch": 0},
)
@@ -62,7 +62,7 @@ async def test_http_request_udf_with_qps():
qps=2,
)
- batch = [{"text": "hello1"}, {"text": "hello2"}]
+ batch = {"__data": [{"text": "hello1"}, {"text": "hello2"}]}
with patch("aiohttp.ClientSession") as mock_session_cls, patch(
"time.time"
@@ -80,7 +80,7 @@ async def test_http_request_udf_with_qps():
mock_time.side_effect = [0, 0.1, 0.2]
results = []
- async for result in udf.udf(batch):
+ async for result in udf(batch):
results.append(result)
assert len(results) == 2
diff --git a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py
index d91885c99ae0f..0d6d14481eea8 100644
--- a/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py
+++ b/python/ray/llm/tests/batch/stages/test_prepare_image_stage.py
@@ -46,21 +46,23 @@ async def test_prepare_image_udf_basic(mock_image_processor, mock_image):
udf = PrepareImageUDF(data_column="__data")
# Test batch with one message containing an image URL
- batch = [
- {
- "messages": [
- {
- "content": [
- {"type": "image", "image": "http://example.com/image.jpg"}
- ]
- }
- ]
- }
- ]
+ batch = {
+ "__data": [
+ {
+ "messages": [
+ {
+ "content": [
+ {"type": "image", "image": "http://example.com/image.jpg"}
+ ]
+ }
+ ]
+ }
+ ]
+ }
results = []
- async for result in udf.udf(batch):
- results.append(result)
+ async for result in udf(batch):
+ results.append(result["__data"][0])
assert len(results) == 1
assert "image" in results[0]
@@ -74,22 +76,24 @@ async def test_prepare_image_udf_multiple_images(mock_image_processor, mock_imag
udf = PrepareImageUDF(data_column="__data")
# Test batch with multiple images in one message
- batch = [
- {
- "messages": [
- {
- "content": [
- {"type": "image", "image": "http://example.com/image1.jpg"},
- {"type": "image", "image": "http://example.com/image2.jpg"},
- ]
- }
- ]
- }
- ]
+ batch = {
+ "__data": [
+ {
+ "messages": [
+ {
+ "content": [
+ {"type": "image", "image": "http://example.com/image1.jpg"},
+ {"type": "image", "image": "http://example.com/image2.jpg"},
+ ]
+ }
+ ]
+ }
+ ]
+ }
results = []
- async for result in udf.udf(batch):
- results.append(result)
+ async for result in udf(batch):
+ results.append(result["__data"][0])
assert len(results) == 1
assert len(results[0]["image"]) == 2
@@ -101,14 +105,14 @@ async def test_prepare_image_udf_no_images(mock_image_processor):
udf = PrepareImageUDF(data_column="__data")
# Test batch with no images
- batch = [{"messages": [{"content": "Hello, world!"}]}]
+ batch = {"__data": [{"messages": [{"content": "Hello, world!"}]}]}
results = []
- async for result in udf.udf(batch):
- results.append(result)
+ async for result in udf(batch):
+ results.append(result["__data"][0])
assert len(results) == 1
- assert results[0] == {}
+ assert results[0] == {"messages": [{"content": "Hello, world!"}]}
@pytest.mark.asyncio
@@ -143,14 +147,16 @@ async def test_prepare_image_udf_invalid_image_type(mock_image_processor):
udf = PrepareImageUDF(data_column="__data")
# Test batch with invalid image type
- batch = [
- {
- "messages": [
- {"content": [{"type": "image", "image": 123}]} # Invalid image type
- ]
- }
- ]
+ batch = {
+ "__data": [
+ {
+ "messages": [
+ {"content": [{"type": "image", "image": 123}]} # Invalid image type
+ ]
+ }
+ ]
+ }
with pytest.raises(ValueError, match="Cannot handle image type"):
- async for _ in udf.udf(batch):
+ async for _ in udf(batch):
pass
diff --git a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py
index 8e940e10ebc0e..b3ccee1cdc945 100644
--- a/python/ray/llm/tests/batch/stages/test_tokenize_stage.py
+++ b/python/ray/llm/tests/batch/stages/test_tokenize_stage.py
@@ -26,11 +26,11 @@ async def test_tokenize_udf_basic(mock_tokenizer_setup):
]
udf = TokenizeUDF(data_column="__data", model="test-model")
- batch = [{"prompt": "Hello"}, {"prompt": "World"}]
+ batch = {"__data": [{"prompt": "Hello"}, {"prompt": "World"}]}
results = []
- async for result in udf.udf(batch):
- results.append(result)
+ async for result in udf(batch):
+ results.append(result["__data"][0])
assert len(results) == 2
assert all(result["tokenized_prompt"] == [1, 2, 3] for result in results)
@@ -46,14 +46,16 @@ async def test_detokenize_udf_basic(mock_tokenizer_setup):
mock_tokenizer.batch_decode.return_value = ["Hello", "World"]
udf = DetokenizeUDF(data_column="__data", model="test-model")
- batch = [
- {"generated_tokens": [1, 2, 3]},
- {"generated_tokens": [4, 5, 6]},
- ]
+ batch = {
+ "__data": [
+ {"generated_tokens": [1, 2, 3]},
+ {"generated_tokens": [4, 5, 6]},
+ ]
+ }
results = []
- async for result in udf.udf(batch):
- results.append(result)
+ async for result in udf(batch):
+ results.append(result["__data"][0])
assert len(results) == 2
assert results[0]["generated_text"] == "Hello"