responses_api_models/vllm_model/app.py (102 additions, 1 deletion)
@@ -67,6 +67,8 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
     uses_reasoning_parser: bool
     replace_developer_role_with_system: bool = False
 
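+    # Route requests to the backend's native Responses endpoint instead of
+    # converting them to Chat Completions first.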
+    use_responses_endpoint: bool = False
+
     chat_template_kwargs: Optional[Dict[str, Any]] = None
 
     # Corresponds to the extra_body of OpenAI Client.
@@ -101,7 +103,106 @@ def model_post_init(self, context):
     async def responses(
         self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body()
     ) -> NeMoGymResponse:
-        # Response Create Params -> Chat Completion Create Params
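+        # Pin each session to one backend client (round-robin on first use)
+        # so all requests from a session reach the same vLLM server.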
+        session_id = request.session[SESSION_ID_KEY]
+        if session_id not in self._session_id_to_client:
+            client_idx = len(self._session_id_to_client) % len(self._clients)
+            client = self._clients[client_idx]
+            self._session_id_to_client[session_id] = client
+        client = self._session_id_to_client[session_id]
+
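+        # Dispatch: native Responses endpoint if enabled, otherwise the
+        # existing Responses -> Chat Completions translation path.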
+        if self.config.use_responses_endpoint:
+            return await self._call_responses(client, body)
+
+        return await self._call_chat_completions(request, body)
+
+    async def _call_responses(
+        self, client: NeMoGymAsyncOpenAI, body: NeMoGymResponseCreateParamsNonStreaming
+    ) -> NeMoGymResponse:
+        body_dict = body.model_dump(exclude_unset=True)
+        body_dict["model"] = self.config.model
+
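+        # Request token-level detail from vLLM so token IDs and logprobs
+        # can be attached to the response message below.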
+        if self.config.return_token_id_information:
+            body_dict["enable_response_messages"] = True
+            body_dict["top_logprobs"] = 1
+            if "include" not in body_dict:
+                body_dict["include"] = []
+            if "message.output_text.logprobs" not in body_dict["include"]:
+                body_dict["include"].append("message.output_text.logprobs")
+
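+        # Merge config-level extra_body defaults; per-request fields win.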
+        if self.config.extra_body:
+            body_dict = {**self.config.extra_body, **body_dict}
+
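+        # vLLM rejects over-long prompts with HTTP 400; treat that as a
+        # truncated (empty) completion rather than a hard failure.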
+        try:
+            vllm_response_dict = await client.create_response(**body_dict)
+        except ClientResponseError as e:
+            result_content_str = e.response_content.decode()
+            is_out_of_context_length = e.status == 400 and (
+                "context length" in result_content_str
+                or "max_tokens" in result_content_str
+                or "max_model_len" in result_content_str
+            )
+            if is_out_of_context_length:
+                return NeMoGymResponse(
+                    id=f"resp_{uuid4().hex}",
+                    created_at=int(time()),
+                    model=self.config.model,
+                    object="response",
+                    parallel_tool_calls=True,
+                    tool_choice="auto",
+                    tools=[],
+                    output=[
+                        NeMoGymResponseOutputMessage(
+                            id=f"msg_{uuid4().hex}",
+                            role="assistant",
+                            content=[NeMoGymResponseOutputText(type="output_text", text="", annotations=[])],
+                            status="completed",
+                            type="message",
+                        )
+                    ],
+                    incomplete_details={"reason": "max_output_tokens"},
+                )
+            else:
+                raise
+
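+        # Attach prompt/generation token IDs and per-token logprobs from
+        # vLLM's message-level fields onto the assistant output message.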
+        if self.config.return_token_id_information:
+            prompt_token_ids = vllm_response_dict["input_messages"][0]["tokens"]
+            generation_token_ids = vllm_response_dict["output_messages"][0]["tokens"]
+
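+            # Only the first input/output message's tokens are captured above.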
+            output = vllm_response_dict.get("output", [])
+            for output_item in output:
+                if output_item.get("type") == "message" and output_item.get("role") == "assistant":
+                    output_item["prompt_token_ids"] = prompt_token_ids
+                    output_item["generation_token_ids"] = generation_token_ids
+
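+                    # Collect per-token logprobs and rebuild output_text
+                    # items without the raw logprob payload.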
+                    generation_log_probs = []
+                    content = output_item.get("content", [])
+                    new_content = []
+                    for content_item in content:
+                        if content_item.get("type") == "output_text":
+                            logprobs = content_item.get("logprobs") or []
+                            for logprob_item in logprobs:
+                                generation_log_probs.append(logprob_item.get("logprob", 0.0))
+                            new_content_item = {
+                                "type": content_item["type"],
+                                "text": content_item["text"],
+                                "annotations": content_item.get("annotations", []),
+                            }
+                            new_content.append(new_content_item)
+                        else:
+                            new_content.append(content_item)
+                    if new_content:
+                        output_item["content"] = new_content
+                    if generation_log_probs:
+                        output_item["generation_log_probs"] = generation_log_probs
+
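+        # input_messages/output_messages are vLLM-side extras, not part of
+        # the NeMoGymResponse schema, so drop them before validation.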
+        vllm_response_dict.pop("input_messages", None)
+        vllm_response_dict.pop("output_messages", None)
+
+        return NeMoGymResponse.model_validate(vllm_response_dict)
+
+    async def _call_chat_completions(
+        self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming
+    ) -> NeMoGymResponse:
         chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
         body.model = self.config.model
 
responses_api_models/vllm_model/configs/vllm_model.yaml (1 addition)
@@ -7,3 +7,4 @@ policy_model:
   model: ${policy_model_name}
   return_token_id_information: false
   uses_reasoning_parser: true
+  use_responses_endpoint: false
(second config file, filename not shown in this view; 1 addition)
@@ -7,3 +7,4 @@ policy_model:
   model: ${policy_model_name}
   return_token_id_information: true
   uses_reasoning_parser: true
+  use_responses_endpoint: false