[LLM Batch] Support out-of-order UDF outputs #50169

Open · wants to merge 3 commits into master
42 changes: 37 additions & 5 deletions python/ray/llm/_internal/batch/stages/base.py
@@ -73,6 +73,10 @@ class StatefulStageUDF:
column for the next stage.
"""

# The internal column name for the index of the row in the batch.
# This is used to align the output of the UDF with the input batch.
idx_in_batch_column: str = "__idx_in_batch"

def __init__(self, data_column: str):
self.data_column = data_column

@@ -143,6 +147,10 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
inputs = inputs.tolist()
self.validate_inputs(inputs)

# Assign the index of the row in the batch to the idx_in_batch_column.
for idx, row in enumerate(inputs):
row[self.idx_in_batch_column] = idx

        # Always stream the outputs one by one to better overlap
# batches. For example, when the output batch size is 64, Ray Data
# will collect 64 outputs, and 1) send the batch of 64 to the next stage,
@@ -151,12 +159,29 @@ async def __call__(self, batch: Dict[str, Any]) -> AsyncIterator[Dict[str, Any]]
# for 2 batches (63 + 63 > 64) to continue proceeding. On the other hand,
# if we stream outputs one-by-one, Ray Data can form a batch of 64 before
# the second batch is done.
idx = 0
        is_output = [False] * len(inputs)
        async for output in self.udf(inputs):
            if self.idx_in_batch_column not in output:
                raise ValueError(
                    "The output of the UDF must contain the column "
                    f"{self.idx_in_batch_column}."
                )
            idx_in_batch = output.pop(self.idx_in_batch_column)
            if is_output[idx_in_batch]:
                raise ValueError(
                    f"Row {idx_in_batch} is output twice. "
                    "This is likely because the UDF is not one-to-one."
                )
            is_output[idx_in_batch] = True

# Add stage outputs to the data column of the row.
inputs[idx].update(output)
yield {self.data_column: [inputs[idx]]}
idx += 1
inputs[idx_in_batch].pop(self.idx_in_batch_column)
inputs[idx_in_batch].update(output)
yield {self.data_column: [inputs[idx_in_batch]]}

        if not all(is_output):
            missing_rows = [i for i, o in enumerate(is_output) if not o]
            raise ValueError(f"Rows {missing_rows} were not output by the UDF.")

def validate_inputs(self, inputs: List[Dict[str, Any]]):
"""Validate the inputs to make sure the required keys are present.
@@ -167,11 +192,18 @@ def validate_inputs(self, inputs: List[Dict[str, Any]]):
Raises:
ValueError: If the required keys are not found.
"""
input_keys = set(inputs[0].keys())

if self.idx_in_batch_column in input_keys:
raise ValueError(
f"The input column {self.idx_in_batch_column} is reserved "
"for internal use."
)

expected_input_keys = self.expected_input_keys
if not expected_input_keys:
return

input_keys = set(inputs[0].keys())
missing_required = set(expected_input_keys) - input_keys
if missing_required:
raise ValueError(
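For context, here is a minimal, self-contained sketch of the realignment contract this file introduces. Everything except the `__idx_in_batch` column name is hypothetical (a toy UDF and driver, not the Ray Data code path):

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List

IDX = "__idx_in_batch"  # mirrors StatefulStageUDF.idx_in_batch_column


async def toy_udf(rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
    # A hypothetical UDF that happens to finish rows in reverse order.
    for row in reversed(rows):
        yield {IDX: row[IDX], "processed": row["value"] * 2}


async def realign(rows: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
    # Tag each input row with its position in the batch.
    for i, row in enumerate(rows):
        row[IDX] = i
    seen = [False] * len(rows)
    async for out in toy_udf(rows):
        i = out.pop(IDX)  # the index says which input row this output belongs to
        assert not seen[i], f"row {i} was output twice"
        seen[i] = True
        rows[i].pop(IDX)
        rows[i].update(out)
        yield rows[i]  # stream rows one at a time, in completion order
    assert all(seen), "every input row must be output exactly once"


async def main() -> None:
    async for row in realign([{"value": 1}, {"value": 2}]):
        print(row)  # {'value': 2, 'processed': 4}, then {'value': 1, 'processed': 2}


asyncio.run(main())
```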
16 changes: 11 additions & 5 deletions python/ray/llm/_internal/batch/stages/chat_template_stage.py
@@ -37,12 +37,18 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
Yields:
A generator of rows with the chat template applied.
"""
for prompt in self.tokenizer.apply_chat_template(
[row["messages"].tolist() for row in batch],
tokenize=False,
add_generation_prompt=True,
for row, prompt in zip(
batch,
self.tokenizer.apply_chat_template(
[row["messages"].tolist() for row in batch],
tokenize=False,
add_generation_prompt=True,
),
):
yield {"prompt": prompt}
yield {
self.idx_in_batch_column: row[self.idx_in_batch_column],
"prompt": prompt,
}

@property
def expected_input_keys(self) -> List[str]:
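The change above is an instance of the pattern every stage UDF now follows: when the backend call is batched, zip its results back with the input rows and forward the index column. A hedged toy version (`render` is a stand-in for `tokenizer.apply_chat_template`, not the real API):

```python
from typing import Any, AsyncIterator, Dict, List


def render(messages_batch: List[List[Dict[str, str]]]) -> List[str]:
    # Stand-in for tokenizer.apply_chat_template(..., tokenize=False).
    return [f"<chat>{msgs[-1]['content']}</chat>" for msgs in messages_batch]


async def udf(batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
    prompts = render([row["messages"] for row in batch])
    for row, prompt in zip(batch, prompts):
        # zip() keeps row/prompt alignment; forwarding the index column lets
        # the base class realign outputs even if a future UDF reorders them.
        yield {"__idx_in_batch": row["__idx_in_batch"], "prompt": prompt}
```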
15 changes: 12 additions & 3 deletions python/ray/llm/_internal/batch/stages/http_request_stage.py
@@ -74,12 +74,21 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
headers=headers,
json=json_body,
)
pending_requests.append(request)
pending_requests.append((row[self.idx_in_batch_column], request))

# Now receive all responses
for request in pending_requests:
for idx_in_batch, request in pending_requests:
async with await request as response:
yield await response.json()
resp_json = await response.json()
if self.idx_in_batch_column in resp_json:
raise ValueError(
"The response of the HTTP request must not contain "
f"the column {self.idx_in_batch_column}."
)
yield {
self.idx_in_batch_column: idx_in_batch,
**resp_json,
}


class HttpRequestStage(StatefulStage):
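This diff still awaits responses in submission order; the index column is what would make a completion-order loop safe. A sketch of that possible follow-up, under the assumption that requests already run as tasks (`fake_post` stands in for the aiohttp call):

```python
import asyncio
from typing import Any, AsyncIterator, Dict, List


async def fake_post(idx: int, delay: float) -> Dict[str, Any]:
    # Stand-in for an aiohttp POST; responses complete out of order.
    await asyncio.sleep(delay)
    return {"__idx_in_batch": idx, "response": f"reply-{idx}"}


async def udf(batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]:
    tasks = [
        asyncio.create_task(fake_post(row["__idx_in_batch"], row["delay"]))
        for row in batch
    ]
    # Yield each response as soon as it completes; ordering no longer matters
    # because every output carries the index of its originating row.
    for fut in asyncio.as_completed(tasks):
        yield await fut


async def main() -> None:
    batch = [
        {"__idx_in_batch": 0, "delay": 0.2},
        {"__idx_in_batch": 1, "delay": 0.05},
    ]
    async for out in udf(batch):
        print(out)  # row 1 arrives before row 0


asyncio.run(main())
```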
28 changes: 20 additions & 8 deletions python/ray/llm/_internal/batch/stages/prepare_image_stage.py
@@ -342,18 +342,30 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
flat_all_image_info = [img for imgs in all_image_info for img in imgs]
flat_all_images = await self.image_processor.process(flat_all_image_info)

idx = 0
        # TODO: We now use asyncio.gather to process all images in this batch,
        # so the outputs here must be in order. However, it would be more
        # efficient to support out-of-order outputs so that we are not
        # blocked by slow image downloads.
img_start_idx = 0
idx_in_batch = 0
for image_info_per_req in all_image_info:
num_images_in_req = len(image_info_per_req)
ret = {self.idx_in_batch_column: idx_in_batch}
if num_images_in_req == 0:
yield {}
yield ret
else:
images = flat_all_images[idx : idx + num_images_in_req]
yield {
"image": images,
"image_sizes": [(img.width, img.height) for img in images],
}
idx += num_images_in_req
images = flat_all_images[
img_start_idx : img_start_idx + num_images_in_req
]
ret.update(
{
"image": images,
"image_sizes": [(img.width, img.height) for img in images],
}
)
yield ret
img_start_idx += num_images_in_req
idx_in_batch += 1

@property
def expected_input_keys(self) -> List[str]:
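The `img_start_idx` bookkeeping above re-splits the flattened image list back into per-request groups. A minimal sketch with placeholder strings instead of images (only the list lengths matter here):

```python
from typing import Any, Dict, Iterator, List


def group_images(
    all_image_info: List[List[Any]], flat_all_images: List[Any]
) -> Iterator[Dict[str, Any]]:
    # Re-split the flat processed-image list into per-request groups,
    # tagging each group with its row index in the batch.
    img_start_idx = 0
    for idx_in_batch, image_info_per_req in enumerate(all_image_info):
        n = len(image_info_per_req)
        ret: Dict[str, Any] = {"__idx_in_batch": idx_in_batch}
        if n > 0:
            ret["image"] = flat_all_images[img_start_idx : img_start_idx + n]
            img_start_idx += n
        yield ret


# Request 0 has two images, request 1 none, request 2 one.
print(list(group_images([["a", "b"], [], ["c"]], ["A", "B", "C"])))
# [{'__idx_in_batch': 0, 'image': ['A', 'B']},
#  {'__idx_in_batch': 1},
#  {'__idx_in_batch': 2, 'image': ['C']}]
```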
10 changes: 8 additions & 2 deletions python/ray/llm/_internal/batch/stages/tokenize_stage.py
@@ -41,7 +41,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
batch,
self.tokenizer([row["prompt"] for row in batch])["input_ids"],
):
yield {"tokenized_prompt": prompt_token_ids}
yield {
self.idx_in_batch_column: row[self.idx_in_batch_column],
"tokenized_prompt": prompt_token_ids,
}

@property
def expected_input_keys(self) -> List[str]:
@@ -96,7 +99,10 @@ async def udf(self, batch: List[Dict[str, Any]]) -> AsyncIterator[Dict[str, Any]]
skip_special_tokens=True,
),
):
yield {"generated_text": generated_text}
yield {
self.idx_in_batch_column: row[self.idx_in_batch_column],
"generated_text": generated_text,
}

@property
def expected_input_keys(self) -> List[str]:
46 changes: 37 additions & 9 deletions python/ray/llm/tests/batch/stages/test_base.py
@@ -40,11 +40,26 @@ def to_string(x: dict) -> dict:

class TestStatefulStageUDF:
class SimpleUDF(StatefulStageUDF):
def __init__(
self,
data_column: str,
udf_output_missing_idx_in_batch_column: bool = False,
):
super().__init__(data_column)
self.udf_output_missing_idx_in_batch_column = (
udf_output_missing_idx_in_batch_column
)

async def udf(
self, rows: list[Dict[str, Any]]
) -> AsyncIterator[Dict[str, Any]]:
for row in rows:
yield {"processed": row["value"] * 2}

            # Intentionally yield the outputs in reversed order to test
            # out-of-order support.
for row in rows[::-1]:
ret = {"processed": row["value"] * 2}
if not self.udf_output_missing_idx_in_batch_column:
ret[self.idx_in_batch_column] = row[self.idx_in_batch_column]
yield ret

@property
def expected_input_keys(self) -> List[str]:
@@ -55,20 +70,20 @@ async def test_basic_processing(self):
udf = self.SimpleUDF(data_column="__data")

batch = {
"__data": [{"value": 1, "extra": "a"}, {"value": 2, "extra": "b"}],
"__data": [{"value": 1, "extra": 10}, {"value": 2, "extra": 20}],
}

results = []
async for result in udf(batch):
results.append(result)

assert len(results) == 2
assert results[0] == {
"__data": [{"processed": 2, "value": 1, "extra": "a"}],
}
assert results[1] == {
"__data": [{"processed": 4, "value": 2, "extra": "b"}],
}
        for result in results:
            data = result["__data"][0]
            val = data["value"]
            assert data["processed"] == val * 2
            assert data["extra"] == 10 * val

@pytest.mark.asyncio
async def test_missing_data_column(self):
@@ -90,6 +105,19 @@ async def test_missing_required_key(self):
async for _ in udf(batch):
pass

@pytest.mark.asyncio
async def test_missing_idx_in_batch_column(self):
udf = self.SimpleUDF(
data_column="__data",
udf_output_missing_idx_in_batch_column=True,
)

batch = {"__data": [{"value": 1}]}

with pytest.raises(ValueError):
async for _ in udf(batch):
pass


def test_stateful_stage():
udf = TestStatefulStageUDF.SimpleUDF(data_column="__data")
64 changes: 30 additions & 34 deletions python/ray/llm/tests/batch/stages/test_chat_template_stage.py
@@ -24,25 +24,23 @@ async def test_chat_template_udf_basic(mock_tokenizer_setup):
model="test-model",
)

batch = [
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "Hello AI"}]
)
}
]
batch = {
"__data": [
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "Hello AI"}]
)
}
]
}

results = []
async for result in udf.udf(batch):
async for result in udf(batch):
results.append(result)

assert len(results) == 1
assert results[0] == {"prompt": "<chat>Hello AI</chat>"}
mock_tokenizer.apply_chat_template.assert_called_once_with(
[batch[0]["messages"].tolist()],
tokenize=False,
add_generation_prompt=True,
)
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
mock_tokenizer.apply_chat_template.assert_called_once()


@pytest.mark.asyncio
@@ -58,31 +56,29 @@ async def test_chat_template_udf_multiple_messages(mock_tokenizer_setup):
model="test-model",
)

batch = [
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "Hello AI"}]
)
},
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "How are you?"}]
)
},
]
batch = {
"__data": [
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "Hello AI"}]
)
},
{
"messages": MagicMock(
tolist=lambda: [{"role": "user", "content": "How are you?"}],
)
},
]
}

results = []
async for result in udf.udf(batch):
async for result in udf(batch):
results.append(result)

assert len(results) == 2
assert results[0] == {"prompt": "<chat>Hello AI</chat>"}
assert results[1] == {"prompt": "<chat>How are you?</chat>"}
mock_tokenizer.apply_chat_template.assert_called_once_with(
[msg["messages"].tolist() for msg in batch],
tokenize=False,
add_generation_prompt=True,
)
assert results[0]["__data"][0]["prompt"] == "<chat>Hello AI</chat>"
assert results[1]["__data"][0]["prompt"] == "<chat>How are you?</chat>"
mock_tokenizer.apply_chat_template.assert_called_once()


def test_chat_template_udf_expected_input_keys(mock_tokenizer_setup):
12 changes: 6 additions & 6 deletions python/ray/llm/tests/batch/stages/test_http_request_stage.py
@@ -32,7 +32,7 @@ async def test_http_request_udf_basic():
qps=None,
)

batch = [{"text": "hello", "metadata": "test"}]
batch = {"__data": [{"text": "hello", "metadata": "test"}]}

with patch("aiohttp.ClientSession") as mock_session_cls:
session = AsyncMock()
Expand All @@ -41,16 +41,16 @@ async def test_http_request_udf_basic():
)
mock_session_cls.return_value.__aenter__.return_value = session

async for result in udf.udf(batch):
assert result == {"response": "test"}
async for result in udf(batch):
assert result["__data"][0]["response"] == "test"

session.post.assert_called_once_with(
"http://test.com/api",
headers={
"Content-Type": "application/json",
"Authorization": "Bearer 1234567890",
},
json={"text": "hello", "metadata": "test"},
json={"text": "hello", "metadata": "test", "__idx_in_batch": 0},
)


@@ -62,7 +62,7 @@ async def test_http_request_udf_with_qps():
qps=2,
)

batch = [{"text": "hello1"}, {"text": "hello2"}]
batch = {"__data": [{"text": "hello1"}, {"text": "hello2"}]}

with patch("aiohttp.ClientSession") as mock_session_cls, patch(
"time.time"
@@ -80,7 +80,7 @@
mock_time.side_effect = [0, 0.1, 0.2]

results = []
async for result in udf.udf(batch):
async for result in udf(batch):
results.append(result)

assert len(results) == 2