responses_api_models/vllm_model/app.py (130 additions, 0 deletions)
@@ -20,6 +20,7 @@

from aiohttp.client_exceptions import ClientResponseError
from fastapi import Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from nemo_gym.base_responses_api_model import (
@@ -67,6 +68,8 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
    uses_reasoning_parser: bool
    replace_developer_role_with_system: bool = False

    use_native_responses_api: bool = False

    chat_template_kwargs: Optional[Dict[str, Any]] = None

    # Corresponds to the extra_body of OpenAI Client.
Expand Down Expand Up @@ -101,6 +104,133 @@ def model_post_init(self, context):
    async def responses(
        self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body()
    ) -> NeMoGymResponse:
        # Pin each session to a single vLLM client, assigned round-robin on first use.
        session_id = request.session[SESSION_ID_KEY]
        if session_id not in self._session_id_to_client:
            client_idx = len(self._session_id_to_client) % len(self._clients)
            client = self._clients[client_idx]
            self._session_id_to_client[session_id] = client
        client = self._session_id_to_client[session_id]

        if self.config.use_native_responses_api:
            body_dict = body.model_dump(exclude_unset=True)
            body_dict["model"] = self.config.model

            if self.config.return_token_id_information:
                # Request per-token logprobs so token IDs can be recovered below.
                body_dict["top_logprobs"] = 1
                if "include" not in body_dict:
                    body_dict["include"] = []
                if "message.output_text.logprobs" not in body_dict["include"]:
                    body_dict["include"].append("message.output_text.logprobs")

            if self.config.extra_body:
                # Request fields take precedence over configured extra_body defaults.
                body_dict = {**self.config.extra_body, **body_dict}

            try:
                vllm_response_dict = await client.create_response(**body_dict)
            except ClientResponseError as e:
                result_content_str = e.response_content.decode()
                is_out_of_context_length = e.status == 400 and (
                    "context length" in result_content_str
                    or "max_tokens" in result_content_str
                    # Raised on the native responses (Harmony) path; see the review thread below.
                    or "max_model_len" in result_content_str
                )

Contributor:
Can we double-check that these are the same error patterns vLLM will throw for responses as for chat completions?

Contributor (author):
Responses pattern: there appear to be two paths, depending on whether Harmony is used.

Harmony: skips _preprocess_chat and calls _validate_generator_input, which produces the "max_model_len" error pattern (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/responses/serving.py#L301; see also https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/engine/serving.py#L921).

Non-Harmony: calls _preprocess_chat, which produces "context length", and later calls _validate_generator_input as well (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/responses/serving.py#L613).

Chat completions pattern: _preprocess_chat (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/chat_completion/serving.py#L297) calls _validate_input (https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/engine/serving.py#L921), resulting in "context length" in the error message. "max_tokens" seems to come from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/engine/serving.py#L959.

So I think we should keep "context length" and "max_tokens", and also check for "max_model_len" for responses.

                if is_out_of_context_length:
                    # Gracefully return an empty assistant message instead of surfacing the 400.
                    return NeMoGymResponse(
                        id=f"resp_{uuid4().hex}",
                        created_at=int(time()),
                        model=self.config.model,
                        object="response",
                        parallel_tool_calls=True,
                        tool_choice="auto",
                        tools=[],
                        output=[
                            NeMoGymResponseOutputMessage(
                                id=f"msg_{uuid4().hex}",
                                role="assistant",
                                content=[NeMoGymResponseOutputText(type="output_text", text="", annotations=[])],
                                status="completed",
                                type="message",
                            )
                        ],
                    )
                else:
                    raise e

            if self.config.uses_reasoning_parser:
                output = vllm_response_dict.get("output", [])
                # Iterate over a copy: reasoning items are inserted into `output` during the loop.
                for output_item in list(output):
                    if output_item.get("type") == "message" and output_item.get("role") == "assistant":
                        content = output_item.get("content", [])
                        for content_item in content:
                            if content_item.get("type") == "output_text":
                                text = content_item.get("text", "")
                                reasoning_matches, cleaned_text = self._converter._extract_reasoning_from_content(text)
                                if reasoning_matches:
                                    content_item["text"] = cleaned_text
                                    # Surface the extracted reasoning as a separate reasoning item,
                                    # inserted immediately before the message it came from.
                                    reasoning_item = {
                                        "id": f"rs_{uuid4().hex}",
                                        "type": "reasoning",
                                        "summary": [
                                            {"text": reasoning_text, "type": "summary_text"}
                                            for reasoning_text in reasoning_matches
                                        ],
                                        "status": "completed",
                                    }
                                    output_idx = output.index(output_item)
                                    output.insert(output_idx, reasoning_item)

            if self.config.return_token_id_information:
                output = vllm_response_dict.get("output", [])
                for output_item in output:
                    if output_item.get("type") == "message" and output_item.get("role") == "assistant":
                        content = output_item.get("content", [])
                        new_content = []
                        for content_item in content:
                            if content_item.get("type") == "output_text":
                                logprobs = content_item.get("logprobs", [])
                                if logprobs:
                                    generation_token_ids = []
                                    generation_log_probs = []
                                    for logprob_item in logprobs:
                                        token = logprob_item.get("token", "")
                                        # vLLM can emit tokens as "token_id:<id>" strings;
                                        # otherwise fall back to the raw token_id field.
                                        if token.startswith("token_id:"):
                                            token_id = token.removeprefix("token_id:")
                                        else:
                                            token_id = str(logprob_item.get("token_id", token))
                                        generation_token_ids.append(token_id)
                                        generation_log_probs.append(logprob_item.get("logprob", 0.0))

                                    # Re-tokenize the original request to recover prompt token IDs.
                                    tokenize_body_dict = {"model": body_dict["model"]}
                                    if "input" in body_dict:
                                        tokenize_body_dict["messages"] = body_dict["input"]
                                    if "tools" in body_dict:
                                        tokenize_body_dict["tools"] = body_dict["tools"]

                                    tokenize_response = await client.create_tokenize(**tokenize_body_dict)
                                    prompt_token_ids = tokenize_response.get("tokens", [])

                                    output_item["prompt_token_ids"] = prompt_token_ids
                                    output_item["generation_token_ids"] = generation_token_ids
                                    output_item["generation_log_probs"] = generation_log_probs

                                    # Drop the bulky logprobs from the returned content item.
                                    new_content_item = {
                                        "type": content_item["type"],
                                        "text": content_item["text"],
                                        "annotations": content_item.get("annotations", []),
                                    }
                                    new_content.append(new_content_item)
                                else:
                                    new_content.append(content_item)
                            else:
                                new_content.append(content_item)

                        if new_content:
                            output_item["content"] = new_content

            validated_response = NeMoGymResponse.model_validate(vllm_response_dict)
            return JSONResponse(
                content=validated_response.model_dump(mode="json", exclude_none=True),
                status_code=200,
            )

        # Response Create Params -> Chat Completion Create Params
        chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
        body.model = self.config.model
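
A minimal standalone sketch, not part of the diff, of the substring classification discussed in the review thread above; the helper name is hypothetical, and the substrings are the three patterns identified in the thread:

# Sketch only: classify a vLLM 400 error as a context-length overflow.
CONTEXT_OVERFLOW_PATTERNS = ("context length", "max_tokens", "max_model_len")


def is_context_overflow(status: int, message: str) -> bool:
    """True if a vLLM error response looks like a context-length overflow."""
    return status == 400 and any(pattern in message for pattern in CONTEXT_OVERFLOW_PATTERNS)


assert is_context_overflow(400, "This model's maximum context length is 4096 tokens.")
assert not is_context_overflow(500, "internal server error")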
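The implementation of _extract_reasoning_from_content is not shown in this diff. A minimal sketch of what such a helper could look like, assuming the reasoning parser delimits reasoning with <think>...</think> tags (the tag format is an assumption, not something the diff confirms):

import re

# Hypothetical sketch: the real parser's delimiters may differ.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def extract_reasoning_from_content(text: str) -> tuple[list[str], str]:
    """Split reasoning spans out of the visible text, mirroring the call site above."""
    matches = _THINK_RE.findall(text)
    cleaned = _THINK_RE.sub("", text).strip()
    return matches, cleaned


matches, cleaned = extract_reasoning_from_content("<think>plan the reply</think>Hello!")
assert matches == ["plan the reply"] and cleaned == "Hello!"
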
@@ -0,0 +1,10 @@
policy_model:
  responses_api_models:
    vllm_model:
      entrypoint: app.py
      base_url: ${policy_base_url}
      api_key: ${policy_api_key}
      model: ${policy_model_name}
      return_token_id_information: false
      uses_reasoning_parser: true
      use_native_responses_api: true
@@ -0,0 +1,10 @@
policy_model:
  responses_api_models:
    vllm_model:
      entrypoint: app.py
      base_url: ${policy_base_url}
      api_key: ${policy_api_key}
      model: ${policy_model_name}
      return_token_id_information: true
      uses_reasoning_parser: true
      use_native_responses_api: true
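
The two configs differ only in return_token_id_information. When it is true, the handler above attaches token-level fields to each assistant message; a sketch of the resulting message shape, with field names taken from the diff and values purely illustrative:

# Illustrative only: field names come from the diff above, values are made up.
assistant_message = {
    "type": "message",
    "role": "assistant",
    "status": "completed",
    "content": [{"type": "output_text", "text": "Hello!", "annotations": []}],
    "prompt_token_ids": [1, 15043, 29991],   # from the tokenize call
    "generation_token_ids": ["9906", "0"],   # parsed from "token_id:<id>" logprob entries
    "generation_log_probs": [-0.12, -0.03],
}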