diff --git a/responses_api_models/vllm_model/app.py b/responses_api_models/vllm_model/app.py
index 46319303d..feaca789c 100644
--- a/responses_api_models/vllm_model/app.py
+++ b/responses_api_models/vllm_model/app.py
@@ -67,6 +67,8 @@ class VLLMModelConfig(BaseResponsesAPIModelConfig):
     uses_reasoning_parser: bool
     replace_developer_role_with_system: bool = False
+    use_responses_endpoint: bool = False
+    chat_template_kwargs: Optional[Dict[str, Any]] = None  # Corresponds to the extra_body of OpenAI Client.
@@ -101,7 +103,106 @@ def model_post_init(self, context):
     async def responses(
         self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming = Body()
     ) -> NeMoGymResponse:
-        # Response Create Params -> Chat Completion Create Params
+        session_id = request.session[SESSION_ID_KEY]
+        if session_id not in self._session_id_to_client:
+            client_idx = len(self._session_id_to_client) % len(self._clients)
+            client = self._clients[client_idx]
+            self._session_id_to_client[session_id] = client
+        client = self._session_id_to_client[session_id]
+
+        if self.config.use_responses_endpoint:
+            return await self._call_responses(client, body)
+
+        return await self._call_chat_completions(request, body)
+
+    async def _call_responses(
+        self, client: NeMoGymAsyncOpenAI, body: NeMoGymResponseCreateParamsNonStreaming
+    ) -> NeMoGymResponse:
+        body_dict = body.model_dump(exclude_unset=True)
+        body_dict["model"] = self.config.model
+
+        if self.config.return_token_id_information:
+            body_dict["enable_response_messages"] = True
+            body_dict["top_logprobs"] = 1
+            if "include" not in body_dict:
+                body_dict["include"] = []
+            if "message.output_text.logprobs" not in body_dict["include"]:
+                body_dict["include"].append("message.output_text.logprobs")
+
+        if self.config.extra_body:
+            body_dict = {**self.config.extra_body, **body_dict}
+
+        try:
+            vllm_response_dict = await client.create_response(**body_dict)
+        except ClientResponseError as e:
+            result_content_str = e.response_content.decode()
+            is_out_of_context_length = e.status == 400 and (
+                "context length" in result_content_str
+                or "max_tokens" in result_content_str
+                or "max_model_len" in result_content_str
+            )
+            if is_out_of_context_length:
+                return NeMoGymResponse(
+                    id=f"resp_{uuid4().hex}",
+                    created_at=int(time()),
+                    model=self.config.model,
+                    object="response",
+                    parallel_tool_calls=True,
+                    tool_choice="auto",
+                    tools=[],
+                    output=[
+                        NeMoGymResponseOutputMessage(
+                            id=f"msg_{uuid4().hex}",
+                            role="assistant",
+                            content=[NeMoGymResponseOutputText(type="output_text", text="", annotations=[])],
+                            status="completed",
+                            type="message",
+                        )
+                    ],
+                    incomplete_details={"reason": "max_output_tokens"},
+                )
+            else:
+                raise
+
+        if self.config.return_token_id_information:
+            prompt_token_ids = vllm_response_dict["input_messages"][0]["tokens"]
+            generation_token_ids = vllm_response_dict["output_messages"][0]["tokens"]
+
+            output = vllm_response_dict.get("output", [])
+            for output_item in output:
+                if output_item.get("type") == "message" and output_item.get("role") == "assistant":
+                    output_item["prompt_token_ids"] = prompt_token_ids
+                    output_item["generation_token_ids"] = generation_token_ids
+
+                    generation_log_probs = []
+                    content = output_item.get("content", [])
+                    new_content = []
+                    for content_item in content:
+                        if content_item.get("type") == "output_text":
+                            logprobs = content_item.get("logprobs") or []
+                            for logprob_item in logprobs:
+                                generation_log_probs.append(logprob_item.get("logprob", 0.0))
+                            new_content_item = {
+                                "type": content_item["type"],
+                                "text": content_item["text"],
+                                "annotations": content_item.get("annotations", []),
+                            }
+                            new_content.append(new_content_item)
+                        else:
+                            new_content.append(content_item)
+                    if new_content:
+                        output_item["content"] = new_content
+                    if generation_log_probs:
+                        output_item["generation_log_probs"] = generation_log_probs
+
+        vllm_response_dict.pop("input_messages", None)
+        vllm_response_dict.pop("output_messages", None)
+
+        return NeMoGymResponse.model_validate(vllm_response_dict)
+
+    async def _call_chat_completions(
+        self, request: Request, body: NeMoGymResponseCreateParamsNonStreaming
+    ) -> NeMoGymResponse:
         chat_completion_create_params = self._converter.responses_to_chat_completion_create_params(body)
 
         body.model = self.config.model
diff --git a/responses_api_models/vllm_model/configs/vllm_model.yaml b/responses_api_models/vllm_model/configs/vllm_model.yaml
index f7850d900..c77ddcbfe 100644
--- a/responses_api_models/vllm_model/configs/vllm_model.yaml
+++ b/responses_api_models/vllm_model/configs/vllm_model.yaml
@@ -7,3 +7,4 @@ policy_model:
   model: ${policy_model_name}
   return_token_id_information: false
   uses_reasoning_parser: true
+  use_responses_endpoint: false
diff --git a/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml b/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml
index 70727036c..be13f5ae2 100644
--- a/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml
+++ b/responses_api_models/vllm_model/configs/vllm_model_for_training.yaml
@@ -7,3 +7,4 @@ policy_model:
   model: ${policy_model_name}
   return_token_id_information: true
   uses_reasoning_parser: true
+  use_responses_endpoint: false
diff --git a/responses_api_models/vllm_model/tests/test_app.py b/responses_api_models/vllm_model/tests/test_app.py
index 27f7ade2c..5a0d23103 100644
--- a/responses_api_models/vllm_model/tests/test_app.py
+++ b/responses_api_models/vllm_model/tests/test_app.py
@@ -661,7 +661,13 @@ class FakeUUID:
 
 
 class TestApp:
-    def _setup_server(self, monkeypatch: MonkeyPatch):
+    def _setup_server(
+        self,
+        monkeypatch: MonkeyPatch,
+        return_token_id_information: bool = False,
+        uses_reasoning_parser: bool = False,
+        use_responses_endpoint: bool = False,
+    ):
         config = VLLMModelConfig(
             host="0.0.0.0",
             port=8081,
@@ -670,8 +676,9 @@ def _setup_server(self, monkeypatch: MonkeyPatch):
             model="dummy_model",
             entrypoint="",
             name="",
-            return_token_id_information=False,
-            uses_reasoning_parser=False,
+            return_token_id_information=return_token_id_information,
+            uses_reasoning_parser=uses_reasoning_parser,
+            use_responses_endpoint=use_responses_endpoint,
         )
 
         get_global_config_dict_mock = MagicMock()
@@ -2039,6 +2046,364 @@ def test_responses_reasoning_parser(self, monkeypatch: MonkeyPatch):
         actual_messages = mock_method.call_args.kwargs["messages"]
         assert expected_messages == actual_messages
 
+    def test_native_responses_api_basic(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_123",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "msg_456",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Hello! How can I help you?",
+                            "annotations": [],
+                        }
+                    ],
+                    "status": "completed",
+                }
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(input="What is the weather?")
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["output"][0]["content"][0]["text"] == "Hello! How can I help you?"
+        assert mock_create_response.called
+        assert mock_create_response.call_args.kwargs["model"] == "dummy_model"
+
+    def test_native_responses_api_with_reasoning(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_123",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "rs_123",
+                    "type": "reasoning",
+                    "summary": [{"type": "summary_text", "text": "I should check the weather for the user"}],
+                },
+                {
+                    "id": "msg_456",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Let me help you with that!",
+                            "annotations": [],
+                        }
+                    ],
+                    "status": "completed",
+                },
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(input="What is the weather?")
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+
+        assert len(data["output"]) == 2
+        assert data["output"][0]["type"] == "reasoning"
+        assert data["output"][0]["summary"][0]["text"] == "I should check the weather for the user"
+        assert data["output"][1]["type"] == "message"
+        assert data["output"][1]["content"][0]["text"] == "Let me help you with that!"
+
+    def test_native_responses_api_with_token_ids(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, return_token_id_information=True, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_123",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "msg_456",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Hello!",
+                            "annotations": [],
+                            "logprobs": [
+                                {"token": "Hello", "logprob": -0.5},
+                                {"token": "!", "logprob": -1.2},
+                            ],
+                        }
+                    ],
+                    "status": "completed",
+                }
+            ],
+            "input_messages": [{"tokens": [1, 2, 3, 4, 5], "type": "raw_message_tokens"}],
+            "output_messages": [{"tokens": [100, 200], "type": "raw_message_tokens"}],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(input="What is the weather?")
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        output_item = data["output"][0]
+
+        assert "prompt_token_ids" in output_item
+        assert output_item["prompt_token_ids"] == [1, 2, 3, 4, 5]
+        assert "generation_token_ids" in output_item
+        assert output_item["generation_token_ids"] == [100, 200]
+        assert "generation_log_probs" in output_item
+        assert output_item["generation_log_probs"] == [-0.5, -1.2]
+        assert output_item["content"][0].get("logprobs") is None
+        assert "input_messages" not in data
+        assert "output_messages" not in data
+        assert mock_create_response.called
+
+    def test_native_responses_api_tool_calls(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_tools",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "fc_123",
+                    "type": "function_call",
+                    "name": "get_weather",
+                    "arguments": '{"location": "San Francisco"}',
+                    "call_id": "call_123",
+                    "status": "completed",
+                }
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(input="What is the weather in San Francisco?")
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["output"][0]["type"] == "function_call"
+        assert data["output"][0]["name"] == "get_weather"
+        assert data["output"][0]["arguments"] == '{"location": "San Francisco"}'
+        assert mock_create_response.called
+
+    def test_native_responses_api_multiturn(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_multiturn",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "msg_789",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "The capital of France is Paris.",
+                            "annotations": [],
+                        }
+                    ],
+                    "status": "completed",
+                }
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(
+            input=[
+                {"role": "user", "content": "What is the capital of Germany?"},
+                {"role": "assistant", "content": "The capital of Germany is Berlin."},
+                {"role": "user", "content": "What about France?"},
+            ]
+        )
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["output"][0]["content"][0]["text"] == "The capital of France is Paris."
+        assert mock_create_response.called
+        # Verify the input was passed through
+        assert mock_create_response.call_args.kwargs["input"] == [
+            {"role": "user", "content": "What is the capital of Germany?"},
+            {"role": "assistant", "content": "The capital of Germany is Berlin."},
+            {"role": "user", "content": "What about France?"},
+        ]
+
+    def test_native_responses_api_string_input(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_str",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "msg_str",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Hello there!",
+                            "annotations": [],
+                        }
+                    ],
+                    "status": "completed",
+                }
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(input="Hello")
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["output"][0]["content"][0]["text"] == "Hello there!"
+        assert mock_create_response.called
+        assert mock_create_response.call_args.kwargs["input"] == "Hello"
+
+    def test_native_responses_api_with_instructions(self, monkeypatch: MonkeyPatch):
+        server = self._setup_server(monkeypatch, use_responses_endpoint=True)
+        app = server.setup_webserver()
+        client = TestClient(app)
+
+        mock_vllm_response = {
+            "id": "resp_native_inst",
+            "created_at": FIXED_TIME,
+            "model": "dummy_model",
+            "object": "response",
+            "parallel_tool_calls": True,
+            "tool_choice": "auto",
+            "tools": [],
+            "output": [
+                {
+                    "id": "msg_inst",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Ahoy! How can I help ye today?",
+                            "annotations": [],
+                        }
+                    ],
+                    "status": "completed",
+                }
+            ],
+        }
+
+        mock_create_response = AsyncMock(return_value=mock_vllm_response)
+        monkeypatch.setattr(NeMoGymAsyncOpenAI, "create_response", mock_create_response)
+
+        request_body = NeMoGymResponseCreateParamsNonStreaming(
+            input="Hello",
+            instructions="You are a pirate. Always respond like a pirate.",
+        )
+
+        response = client.post(
+            "/v1/responses",
+            json=request_body.model_dump(exclude_unset=True, mode="json"),
+        )
+        assert response.status_code == 200
+
+        data = response.json()
+        assert "Ahoy" in data["output"][0]["content"][0]["text"]
+        assert mock_create_response.called
+        assert (
+            mock_create_response.call_args.kwargs["instructions"] == "You are a pirate. Always respond like a pirate."
+        )
 
 
 class TestVLLMConverter:
     def setup_method(self, _):