diff --git a/src/fast_agent/agents/workflow/agents_as_tools_agent.py b/src/fast_agent/agents/workflow/agents_as_tools_agent.py index 3e5bfc827..414901c97 100644 --- a/src/fast_agent/agents/workflow/agents_as_tools_agent.py +++ b/src/fast_agent/agents/workflow/agents_as_tools_agent.py @@ -243,7 +243,7 @@ class AgentsAsToolsOptions: history_mode: HistoryMode = HistoryMode.FORK max_parallel: int | None = None - child_timeout_sec: int | None = None + child_timeout_sec: float | None = None max_display_instances: int = 20 def __post_init__(self) -> None: diff --git a/src/fast_agent/interfaces.py b/src/fast_agent/interfaces.py index e6b5ce9f0..5ed3a40fa 100644 --- a/src/fast_agent/interfaces.py +++ b/src/fast_agent/interfaces.py @@ -92,6 +92,8 @@ def get_request_params( request_params: RequestParams | None = None, ) -> RequestParams: ... + default_request_params: RequestParams + def add_stream_listener(self, listener: Callable[[StreamChunk], None]) -> Callable[[], None]: ... def add_tool_stream_listener( diff --git a/src/fast_agent/mcp/mcp_content.py b/src/fast_agent/mcp/mcp_content.py index ac5cd24a2..f99ae687a 100644 --- a/src/fast_agent/mcp/mcp_content.py +++ b/src/fast_agent/mcp/mcp_content.py @@ -161,6 +161,7 @@ def MCPPrompt( Path, bytes, ContentBlock, + ResourceContents, ReadResourceResult, PromptMessage, PromptMessageExtended, @@ -262,6 +263,7 @@ def User( Path, bytes, ContentBlock, + ResourceContents, ReadResourceResult, PromptMessage, PromptMessageExtended, @@ -278,6 +280,7 @@ def Assistant( Path, bytes, ContentBlock, + ResourceContents, ReadResourceResult, PromptMessage, PromptMessageExtended, diff --git a/src/fast_agent/mcp/prompt.py b/src/fast_agent/mcp/prompt.py index 0e8031701..87e26a91b 100644 --- a/src/fast_agent/mcp/prompt.py +++ b/src/fast_agent/mcp/prompt.py @@ -10,7 +10,7 @@ from typing import Literal, Union from mcp import CallToolRequest -from mcp.types import ContentBlock, PromptMessage +from mcp.types import ContentBlock, PromptMessage, ReadResourceResult, ResourceContents from fast_agent.mcp.mcp_content import Assistant, MCPPrompt, User from fast_agent.types import LlmStopReason, PromptMessageExtended @@ -39,7 +39,15 @@ class Prompt: def user( cls, *content_items: Union[ - str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended + str, + Path, + bytes, + dict, + ContentBlock, + ResourceContents, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, ], ) -> PromptMessageExtended: """ @@ -62,7 +70,15 @@ def user( def assistant( cls, *content_items: Union[ - str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended + str, + Path, + bytes, + dict, + ContentBlock, + ResourceContents, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, ], stop_reason: LlmStopReason | None = None, tool_calls: dict[str, CallToolRequest] | None = None, @@ -102,7 +118,15 @@ def assistant( def message( cls, *content_items: Union[ - str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended + str, + Path, + bytes, + dict, + ContentBlock, + ResourceContents, + ReadResourceResult, + PromptMessage, + PromptMessageExtended, ], role: Literal["user", "assistant"] = "user", ) -> PromptMessageExtended: diff --git a/tests/e2e/llm/test_llm_e2e_reasoning.py b/tests/e2e/llm/test_llm_e2e_reasoning.py index 7ce9cff89..c08bc177b 100644 --- a/tests/e2e/llm/test_llm_e2e_reasoning.py +++ b/tests/e2e/llm/test_llm_e2e_reasoning.py @@ -46,6 +46,7 @@ def on_chunk(chunk: StreamChunk) -> None: async def _run_turn(agent: LlmAgent, prompt: str) -> tuple[dict[str, int], list[str], str | None]: listener, state = _make_stream_tracker() + assert agent.llm is not None remove = agent.llm.add_stream_listener(listener) try: result = await agent.generate(prompt) diff --git a/tests/e2e/smoke/tensorzero/test_agent_interaction.py b/tests/e2e/smoke/tensorzero/test_agent_interaction.py index 889d7a2d9..d9dfb071a 100644 --- a/tests/e2e/smoke/tensorzero/test_agent_interaction.py +++ b/tests/e2e/smoke/tensorzero/test_agent_interaction.py @@ -48,7 +48,7 @@ async def dummy_agent_func(): async with fast.run() as agent_app: agent_instance = agent_app.default - print(f"\nSending {len(messages_to_send)} messages to agent '{agent_instance._name}'...") + print(f"\nSending {len(messages_to_send)} messages to agent '{agent_instance.name}'...") for i, msg_text in enumerate(messages_to_send): print(f"Sending message {i + 1}: '{msg_text}'") await agent_instance.send(msg_text) diff --git a/tests/e2e/smoke/tensorzero/test_simple_agent_interaction.py b/tests/e2e/smoke/tensorzero/test_simple_agent_interaction.py index 311148014..098dfff1b 100644 --- a/tests/e2e/smoke/tensorzero/test_simple_agent_interaction.py +++ b/tests/e2e/smoke/tensorzero/test_simple_agent_interaction.py @@ -35,8 +35,8 @@ async def dummy_simple_agent_func(): async with fast.run() as agent_app: agent_instance = agent_app.simple_default - print(f"\nSending message to agent '{agent_instance._name}': '{message_to_send}'") + print(f"\nSending message to agent '{agent_instance.name}': '{message_to_send}'") await agent_instance.send(message_to_send) - print(f"Message sent successfully to '{agent_instance._name}'.") + print(f"Message sent successfully to '{agent_instance.name}'.") print("\nSimple agent interaction smoke test completed successfully.") diff --git a/tests/integration/acp/test_acp_skills_manager.py b/tests/integration/acp/test_acp_skills_manager.py index 0042f91ba..c5925ec1d 100644 --- a/tests/integration/acp/test_acp_skills_manager.py +++ b/tests/integration/acp/test_acp_skills_manager.py @@ -243,7 +243,8 @@ async def test_skills_registry_numbered_selection(tmp_path: Path) -> None: assert "Registry set to" in response assert get_settings().skills.marketplace_url == marketplace2.as_posix() # Configured list is preserved - assert len(get_settings().skills.marketplace_urls) == 2 + marketplace_urls = get_settings().skills.marketplace_urls or [] + assert len(marketplace_urls) == 2 # Invalid number response = await handler.execute_command("skills", "registry 99") diff --git a/tests/integration/acp/test_acp_status_line.py b/tests/integration/acp/test_acp_status_line.py index 6bfe88d55..08159f779 100644 --- a/tests/integration/acp/test_acp_status_line.py +++ b/tests/integration/acp/test_acp_status_line.py @@ -4,6 +4,7 @@ import re import sys from pathlib import Path +from typing import cast import pytest from acp.helpers import text_block @@ -37,14 +38,17 @@ def _extract_status_line(meta: object) -> str | None: if not isinstance(meta, dict): return None - field_meta = meta.get("field_meta") + meta_dict = cast("dict[str, object]", meta) + field_meta = meta_dict.get("field_meta") if isinstance(field_meta, dict): - metrics = field_meta.get("openhands.dev/metrics") + field_meta_dict = cast("dict[str, object]", field_meta) + metrics = field_meta_dict.get("openhands.dev/metrics") else: - metrics = meta.get("openhands.dev/metrics") + metrics = meta_dict.get("openhands.dev/metrics") if not isinstance(metrics, dict): return None - status_line = metrics.get("status_line") + metrics_dict = cast("dict[str, object]", metrics) + status_line = metrics_dict.get("status_line") if isinstance(status_line, str) and status_line.strip(): return status_line return None diff --git a/tests/integration/acp/test_acp_terminal.py b/tests/integration/acp/test_acp_terminal.py index 9b2d0a90c..178250b71 100644 --- a/tests/integration/acp/test_acp_terminal.py +++ b/tests/integration/acp/test_acp_terminal.py @@ -130,7 +130,7 @@ async def test_acp_terminal_execution() -> None: # Manually test terminal lifecycle (client creates ID) create_result = await client.create_terminal(command="echo test", session_id=session_id) - terminal_id = create_result.terminalId + terminal_id = create_result.terminal_id # Verify terminal was created with client-generated ID assert terminal_id == "terminal-1" # First terminal @@ -141,7 +141,7 @@ async def test_acp_terminal_execution() -> None: output = await client.terminal_output(session_id=session_id, terminal_id=terminal_id) assert "Executed: echo test" in output.output exit_info = await client.wait_for_terminal_exit(session_id=session_id, terminal_id=terminal_id) - assert exit_info.exitCode == 0 + assert exit_info.exit_code == 0 # Release terminal await client.release_terminal(session_id=session_id, terminal_id=terminal_id) diff --git a/tests/integration/acp/test_acp_terminal_lifecycle.py b/tests/integration/acp/test_acp_terminal_lifecycle.py index 8a315814f..c0b3584c2 100644 --- a/tests/integration/acp/test_acp_terminal_lifecycle.py +++ b/tests/integration/acp/test_acp_terminal_lifecycle.py @@ -34,7 +34,7 @@ async def test_terminal_create_lifecycle() -> None: # Create first terminal result1 = await client.create_terminal(command="echo hello", session_id="test-session") - terminal_id1 = result1.terminalId + terminal_id1 = result1.terminal_id assert terminal_id1 == "terminal-1" assert len(client.terminals) == 1 @@ -42,7 +42,7 @@ async def test_terminal_create_lifecycle() -> None: # Create second terminal result2 = await client.create_terminal(command="pwd", session_id="test-session") - terminal_id2 = result2.terminalId + terminal_id2 = result2.terminal_id assert terminal_id2 == "terminal-2" assert len(client.terminals) == 2 @@ -68,7 +68,7 @@ async def test_terminal_output_retrieval() -> None: # Create terminal result = await client.create_terminal(command="echo test output", session_id="test-session") - terminal_id = result.terminalId + terminal_id = result.terminal_id # Get output output = await client.terminal_output(session_id="test-session", terminal_id=terminal_id) @@ -88,14 +88,14 @@ async def test_terminal_wait_for_exit() -> None: # Create terminal result = await client.create_terminal(command="echo test", session_id="test-session") - terminal_id = result.terminalId + terminal_id = result.terminal_id # Wait for exit (immediate in test client) exit_result = await client.wait_for_terminal_exit( session_id="test-session", terminal_id=terminal_id ) - assert exit_result.exitCode == 0 + assert exit_result.exit_code == 0 assert exit_result.signal is None # Cleanup @@ -110,7 +110,7 @@ async def test_terminal_kill() -> None: # Create terminal result = await client.create_terminal(command="sleep 100", session_id="test-session") - terminal_id = result.terminalId + terminal_id = result.terminal_id # Kill it await client.kill_terminal(session_id="test-session", terminal_id=terminal_id) @@ -123,7 +123,7 @@ async def test_terminal_kill() -> None: exit_result = await client.wait_for_terminal_exit( session_id="test-session", terminal_id=terminal_id ) - assert exit_result.exitCode is None + assert exit_result.exit_code is None assert exit_result.signal == "SIGKILL" # Cleanup @@ -140,7 +140,7 @@ async def test_terminal_release_cleanup() -> None: terminals = [] for i in range(3): result = await client.create_terminal(command=f"echo {i}", session_id="test-session") - terminals.append(result.terminalId) + terminals.append(result.terminal_id) assert len(client.terminals) == 3 @@ -171,7 +171,7 @@ async def test_terminal_missing_id() -> None: exit_result = await client.wait_for_terminal_exit( session_id="test-session", terminal_id="missing" ) - assert exit_result.exitCode is None + assert exit_result.exit_code is None # Kill non-existent terminal (should not error) await client.kill_terminal(session_id="test-session", terminal_id="missing") diff --git a/tests/integration/acp/test_client.py b/tests/integration/acp/test_client.py index 6808004ca..04d245ac1 100644 --- a/tests/integration/acp/test_client.py +++ b/tests/integration/acp/test_client.py @@ -52,7 +52,7 @@ def queue_permission_cancelled(self) -> None: def queue_permission_selected(self, option_id: str) -> None: self.permission_outcomes.append( RequestPermissionResponse( - outcome=AllowedOutcome(optionId=option_id, outcome="selected") + outcome=AllowedOutcome(option_id=option_id, outcome="selected") ) ) @@ -149,7 +149,7 @@ async def create_terminal( } # Return the ID we created - return CreateTerminalResponse(terminalId=terminal_id) + return CreateTerminalResponse(terminal_id=terminal_id) async def terminal_output( self, @@ -161,9 +161,9 @@ async def terminal_output( terminal = self.terminals.get(terminal_id, {}) exit_code = terminal.get("exit_code") if isinstance(exit_code, int) and exit_code >= 0: - exit_status = TerminalExitStatus(exitCode=exit_code) + exit_status = TerminalExitStatus(exit_code=exit_code) elif isinstance(exit_code, int) and exit_code < 0: - exit_status = TerminalExitStatus(exitCode=None, signal="SIGKILL") + exit_status = TerminalExitStatus(exit_code=None, signal="SIGKILL") else: exit_status = None @@ -194,10 +194,12 @@ async def wait_for_terminal_exit( terminal = self.terminals.get(terminal_id, {}) exit_code = terminal.get("exit_code") if isinstance(exit_code, int) and exit_code >= 0: - return WaitForTerminalExitResponse(exitCode=exit_code, signal=None) + return WaitForTerminalExitResponse(exit_code=exit_code, signal=None) # Unknown or negative exit -> model as killed/terminated with no exit code - return WaitForTerminalExitResponse(exitCode=None, signal="SIGKILL" if exit_code else None) + return WaitForTerminalExitResponse( + exit_code=None, signal="SIGKILL" if exit_code else None + ) async def kill_terminal( self, diff --git a/tests/integration/acp/test_set_model_validation.py b/tests/integration/acp/test_set_model_validation.py index 54d9f201b..c8b7e8959 100644 --- a/tests/integration/acp/test_set_model_validation.py +++ b/tests/integration/acp/test_set_model_validation.py @@ -183,8 +183,10 @@ async def test_validate_resolves_aliases_to_hf_models() -> None: resolved_model = model break - if hf_alias is None: + if hf_alias is None or resolved_model is None: pytest.skip("No HF model aliases found in MODEL_ALIASES") + assert hf_alias is not None + assert resolved_model is not None # Extract the expected model ID from the resolved model expected_model_id = resolved_model[3:] # Strip "hf." diff --git a/tests/integration/api/mcp_tools_server.py b/tests/integration/api/mcp_tools_server.py index 0c0f70d96..028873251 100644 --- a/tests/integration/api/mcp_tools_server.py +++ b/tests/integration/api/mcp_tools_server.py @@ -43,9 +43,9 @@ def shirt_colour() -> str: @app.tool(name="implementation", description="Returns the Client implementation") def implementation(ctx: Context) -> str: assert ctx.session.client_params is not None, "Client params should not be None" - clientInfo = ctx.session.client_params.clientInfo or None - - return clientInfo.model_dump_json() + client_info = ctx.session.client_params.clientInfo + assert client_info is not None + return client_info.model_dump_json() if __name__ == "__main__": diff --git a/tests/integration/api/test_logger_textio.py b/tests/integration/api/test_logger_textio.py index 38cb18dde..376e884aa 100644 --- a/tests/integration/api/test_logger_textio.py +++ b/tests/integration/api/test_logger_textio.py @@ -75,6 +75,7 @@ def test_logger_textio_real_process(test_script_path, logger_io): ) # Read and process stderr lines + assert process.stderr is not None for line in process.stderr: logger_io.write(line) diff --git a/tests/integration/api/test_retry_error_channel.py b/tests/integration/api/test_retry_error_channel.py index 725ebcf36..3c7ab0de2 100644 --- a/tests/integration/api/test_retry_error_channel.py +++ b/tests/integration/api/test_retry_error_channel.py @@ -15,8 +15,8 @@ class FailingOpenAILLM(OpenAILLM): """Test double that always raises an APIError.""" - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, provider=Provider.OPENAI, **kwargs) + def __init__(self, **kwargs) -> None: + super().__init__(provider=Provider.OPENAI, **kwargs) self.attempts = 0 async def _apply_prompt_provider_specific( @@ -40,7 +40,8 @@ async def test_retry_exhaustion_returns_error_channel(): assert llm.attempts == 1 # no retries when FAST_AGENT_RETRIES=0 assert response.stop_reason == LlmStopReason.ERROR - assert FAST_AGENT_ERROR_CHANNEL in (response.channels or {}) + assert response.channels is not None + assert FAST_AGENT_ERROR_CHANNEL in response.channels error_block = response.channels[FAST_AGENT_ERROR_CHANNEL][0] assert "request failed" in (get_text(error_block) or "") diff --git a/tests/integration/elicitation/test_elicitation_handler.py b/tests/integration/elicitation/test_elicitation_handler.py index aa79739ab..514e8a4bc 100644 --- a/tests/integration/elicitation/test_elicitation_handler.py +++ b/tests/integration/elicitation/test_elicitation_handler.py @@ -25,9 +25,10 @@ async def custom_elicitation_handler( """Test handler that returns predictable responses for integration testing.""" logger.info(f"Test elicitation handler called with: {params.message}") - if params.requestedSchema: + requested_schema = getattr(params, "requestedSchema", None) + if requested_schema: # Generate test data based on the schema for round-trip verification - properties = params.requestedSchema.get("properties", {}) + properties = requested_schema.get("properties", {}) content: dict[str, Any] = {} # Provide test values for each field diff --git a/tests/integration/elicitation/test_elicitation_integration.py b/tests/integration/elicitation/test_elicitation_integration.py index b8f630273..f7e45d4a7 100644 --- a/tests/integration/elicitation/test_elicitation_integration.py +++ b/tests/integration/elicitation/test_elicitation_integration.py @@ -28,9 +28,10 @@ async def custom_test_elicitation_handler( """Test handler that returns predictable responses for integration testing.""" logger.info(f"Test elicitation handler called with: {params.message}") - if params.requestedSchema: + requested_schema = getattr(params, "requestedSchema", None) + if requested_schema: # Generate test data based on the schema for round-trip verification - properties = params.requestedSchema.get("properties", {}) + properties = requested_schema.get("properties", {}) content: dict[str, Any] = {} # Provide test values for each field diff --git a/tests/integration/elicitation/testing_handlers.py b/tests/integration/elicitation/testing_handlers.py index 22877d95e..e8fc73733 100644 --- a/tests/integration/elicitation/testing_handlers.py +++ b/tests/integration/elicitation/testing_handlers.py @@ -29,9 +29,10 @@ async def auto_accept_test_handler( """ logger.info(f"Auto-accept test handler called: {params.message}") - if params.requestedSchema: + requested_schema = getattr(params, "requestedSchema", None) + if requested_schema: # Generate realistic test data based on schema - content = _generate_test_response(params.requestedSchema) + content = _generate_test_response(requested_schema) return ElicitResult(action="accept", content=content) else: return ElicitResult(action="accept", content={"response": "auto-test-response"}) diff --git a/tests/integration/mcp_ui/test_mcp_ui_integration.py b/tests/integration/mcp_ui/test_mcp_ui_integration.py index ca4c5a0df..cf4acd162 100644 --- a/tests/integration/mcp_ui/test_mcp_ui_integration.py +++ b/tests/integration/mcp_ui/test_mcp_ui_integration.py @@ -23,6 +23,7 @@ async def passthrough_agent(tmp_path): core = Core() await core.initialize() # Avoid auto-opening browser windows during tests + assert core.context.config is not None core.context.config.mcp_ui_mode = "enabled" agent = McpAgentWithUI( diff --git a/tests/integration/roots/root_client.py b/tests/integration/roots/root_client.py index cfa0e1093..e7deb5bd8 100644 --- a/tests/integration/roots/root_client.py +++ b/tests/integration/roots/root_client.py @@ -2,7 +2,7 @@ from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.types import ListRootsResult, Root -from pydantic import AnyUrl +from pydantic import FileUrl async def list_roots_callback(context): @@ -10,11 +10,11 @@ async def list_roots_callback(context): return ListRootsResult( roots=[ Root( - uri=AnyUrl("file://foo/bar"), + uri=FileUrl("file://foo/bar"), name="Home Directory", ), Root( - uri=AnyUrl("file:///tmp"), + uri=FileUrl("file:///tmp"), name="Temp Directory", ), ] diff --git a/tests/unit/acp/test_content_conversion.py b/tests/unit/acp/test_content_conversion.py index 53b76e5ca..41b72ff24 100644 --- a/tests/unit/acp/test_content_conversion.py +++ b/tests/unit/acp/test_content_conversion.py @@ -5,6 +5,7 @@ from __future__ import annotations import base64 +from typing import TYPE_CHECKING, cast from acp.schema import ( BlobResourceContents, @@ -34,6 +35,9 @@ convert_acp_prompt_to_mcp_content_blocks, ) +if TYPE_CHECKING: + from acp.helpers import ContentBlock as ACPContentBlock + class TestTextContentConversion: """Test conversion of TextContentBlock.""" @@ -77,7 +81,7 @@ def test_basic_image_conversion(self): acp_image = ImageContentBlock( type="image", data=image_data, - mimeType="image/png", + mime_type="image/png", ) mcp_content = convert_acp_content_to_mcp(acp_image) @@ -94,7 +98,7 @@ def test_image_with_uri(self): acp_image = ImageContentBlock( type="image", data=image_data, - mimeType="image/jpeg", + mime_type="image/jpeg", uri="file:///path/to/image.jpg", ) @@ -114,7 +118,7 @@ def test_text_resource_conversion(self): type="resource", resource=TextResourceContents( uri="file:///path/to/file.py", - mimeType="text/x-python", + mime_type="text/x-python", text="def hello():\n print('Hello')", ), ) @@ -136,7 +140,7 @@ def test_blob_resource_conversion(self): type="resource", resource=BlobResourceContents( uri="file:///path/to/file.pdf", - mimeType="application/pdf", + mime_type="application/pdf", blob=blob_data, ), ) @@ -182,7 +186,7 @@ def test_mixed_content_prompt(self): type="resource", resource=TextResourceContents( uri="file:///main.py", - mimeType="text/x-python", + mime_type="text/x-python", text="print('hello')", ), ), @@ -190,7 +194,7 @@ def test_mixed_content_prompt(self): ImageContentBlock( type="image", data=image_data, - mimeType="image/png", + mime_type="image/png", ), ] @@ -223,6 +227,8 @@ def test_text_only_prompt(self): assert len(mcp_blocks) == 2 assert all(isinstance(block, MCPTextContent) for block in mcp_blocks) + assert isinstance(mcp_blocks[0], MCPTextContent) + assert isinstance(mcp_blocks[1], MCPTextContent) assert mcp_blocks[0].text == "First message" assert mcp_blocks[1].text == "Second message" @@ -238,7 +244,7 @@ class UnsupportedContent: type = "audio" data = "base64-audio-data" - result = convert_acp_content_to_mcp(UnsupportedContent()) + result = convert_acp_content_to_mcp(cast("ACPContentBlock", UnsupportedContent())) assert result is None def test_prompt_with_unsupported_content_skips_it(self): @@ -247,9 +253,9 @@ def test_prompt_with_unsupported_content_skips_it(self): class UnsupportedContent: type = "audio" - acp_prompt = [ + acp_prompt: list[ACPContentBlock] = [ TextContentBlock(type="text", text="Hello"), - UnsupportedContent(), + cast("ACPContentBlock", UnsupportedContent()), TextContentBlock(type="text", text="World"), ] @@ -258,5 +264,7 @@ class UnsupportedContent: # Should only have the two text blocks assert len(mcp_blocks) == 2 assert all(isinstance(block, MCPTextContent) for block in mcp_blocks) + assert isinstance(mcp_blocks[0], MCPTextContent) + assert isinstance(mcp_blocks[1], MCPTextContent) assert mcp_blocks[0].text == "Hello" assert mcp_blocks[1].text == "World" diff --git a/tests/unit/fast_agent/agents/test_agent_types.py b/tests/unit/fast_agent/agents/test_agent_types.py index f91b4a7de..72c79d4ee 100644 --- a/tests/unit/fast_agent/agents/test_agent_types.py +++ b/tests/unit/fast_agent/agents/test_agent_types.py @@ -41,6 +41,7 @@ def test_instruction_propagates_to_default_request_params(): ) # The instruction should be propagated to default_request_params.systemPrompt + assert config.default_request_params is not None assert config.default_request_params.systemPrompt == instruction, ( f"Expected systemPrompt to be '{instruction}', " f"but got {config.default_request_params.systemPrompt}" @@ -77,6 +78,7 @@ def test_instruction_takes_precedence_over_systemPrompt(): ) # The AgentConfig.instruction should take precedence over systemPrompt in RequestParams + assert config.default_request_params is not None assert config.default_request_params.systemPrompt == instruction, ( f"Expected AgentConfig.instruction ('{instruction}') to override " f"RequestParams.systemPrompt ('{original_system_prompt}'), " diff --git a/tests/unit/fast_agent/agents/test_llm_content_filter.py b/tests/unit/fast_agent/agents/test_llm_content_filter.py index db5ed72d8..b2af19375 100644 --- a/tests/unit/fast_agent/agents/test_llm_content_filter.py +++ b/tests/unit/fast_agent/agents/test_llm_content_filter.py @@ -66,7 +66,10 @@ def message_history(self): def model_info(self): from fast_agent.llm.model_info import ModelInfo - return ModelInfo.from_name(self.model_name, self.provider) + model_name = self.model_name + if model_name is None: + return None + return ModelInfo.from_name(model_name, self.provider) def make_decorator(model_name: str = "passthrough") -> tuple[LlmDecorator, RecordingStubLLM]: @@ -153,6 +156,7 @@ async def test_removes_unsupported_tool_result_content(): assert stub.generated_messages is not None sent_message = stub.generated_messages[0] + assert sent_message.tool_results is not None sanitized_result = sent_message.tool_results["tool1"] # Should have placeholder text since all content was removed assert len(sanitized_result.content) == 1 diff --git a/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py b/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py index dfae7ceb8..bdef12587 100644 --- a/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py +++ b/tests/unit/fast_agent/agents/test_mcp_agent_local_tools.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any import pytest +from mcp.types import TextContent from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.mcp_agent import McpAgent @@ -42,6 +43,10 @@ def __init__(self) -> None: result: CallToolResult = await agent.call_tool("sample_tool", {"video_id": "1234"}) assert not result.isError assert calls == [{"video_id": "1234"}] - assert [block.text for block in result.content or []] == ["transcript for 1234"] + assert result.content is not None + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "transcript for 1234" await agent._aggregator.close() diff --git a/tests/unit/fast_agent/agents/test_mcp_agent_skills.py b/tests/unit/fast_agent/agents/test_mcp_agent_skills.py index b4a786261..fff33fb3a 100644 --- a/tests/unit/fast_agent/agents/test_mcp_agent_skills.py +++ b/tests/unit/fast_agent/agents/test_mcp_agent_skills.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest +from mcp.types import TextContent from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.mcp_agent import McpAgent @@ -73,6 +74,9 @@ async def test_skill_reader_rejects_relative_path(tmp_path: Path) -> None: result = await reader.execute({"path": "skills/alpha/SKILL.md"}) assert result.isError is True + assert result.content is not None + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) assert "Path must be absolute" in result.content[0].text @@ -91,6 +95,9 @@ async def test_skill_reader_blocks_outside_skill_directory(tmp_path: Path) -> No result = await reader.execute({"path": str(outside_file)}) assert result.isError is True + assert result.content is not None + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) assert "not within an allowed skill directory" in result.content[0].text @@ -107,7 +114,10 @@ async def test_skill_reader_reads_valid_skill_file(tmp_path: Path) -> None: result = await reader.execute({"path": str(skill_file)}) assert result.isError is False - assert any("Alpha body" in block.text for block in result.content) + assert result.content is not None + assert any( + isinstance(block, TextContent) and "Alpha body" in block.text for block in result.content + ) @pytest.mark.asyncio diff --git a/tests/unit/fast_agent/agents/workflow/test_agents_as_tools_agent.py b/tests/unit/fast_agent/agents/workflow/test_agents_as_tools_agent.py index 79cbe6b88..81b30332e 100644 --- a/tests/unit/fast_agent/agents/workflow/test_agents_as_tools_agent.py +++ b/tests/unit/fast_agent/agents/workflow/test_agents_as_tools_agent.py @@ -1,9 +1,10 @@ import asyncio +from collections.abc import Sequence from unittest.mock import AsyncMock import pytest from mcp import CallToolRequest, Tool -from mcp.types import CallToolRequestParams +from mcp.types import CallToolRequestParams, PromptMessage, TextContent from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent @@ -13,7 +14,7 @@ ) from fast_agent.constants import FAST_AGENT_ERROR_CHANNEL from fast_agent.mcp.helpers.content_helpers import text_content -from fast_agent.types import PromptMessageExtended +from fast_agent.types import PromptMessageExtended, RequestParams class FakeChildAgent(LlmAgent): @@ -24,7 +25,15 @@ def __init__(self, name: str, response_text: str = "ok", delay: float = 0): self._response_text = response_text self._delay = delay - async def generate(self, messages, request_params=None): + async def generate( + self, + messages: str + | PromptMessage + | PromptMessageExtended + | Sequence[str | PromptMessage | PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + ) -> PromptMessageExtended: if self._delay: await asyncio.sleep(self._delay) return PromptMessageExtended( @@ -39,7 +48,15 @@ async def spawn_detached_instance(self, name: str | None = None): class ErrorChannelChild(FakeChildAgent): - async def generate(self, messages, request_params=None): + async def generate( + self, + messages: str + | PromptMessage + | PromptMessageExtended + | Sequence[str | PromptMessage | PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + ) -> PromptMessageExtended: return PromptMessageExtended( role="assistant", content=[], @@ -50,7 +67,15 @@ async def generate(self, messages, request_params=None): class StubNestedAgentsAsTools(AgentsAsToolsAgent): """Stub AgentsAsToolsAgent that responds without hitting an LLM.""" - async def generate(self, messages, request_params=None): + async def generate( + self, + messages: str + | PromptMessage + | PromptMessageExtended + | Sequence[str | PromptMessage | PromptMessageExtended], + request_params: RequestParams | None = None, + tools: list[Tool] | None = None, + ) -> PromptMessageExtended: return PromptMessageExtended( role="assistant", content=[text_content(f"{self.name}-reply")], @@ -69,7 +94,7 @@ async def test_list_tools_merges_base_and_child(): # Inject a base MCP tool via the filtered MCP path to ensure merge behavior. base_tool = Tool(name="base_tool", description="base", inputSchema={"type": "object"}) - agent._get_filtered_mcp_tools = AsyncMock(return_value=[base_tool]) + setattr(agent, "_get_filtered_mcp_tools", AsyncMock(return_value=[base_tool])) result = await agent.list_tools() tool_names = {t.name for t in result.tools} @@ -94,7 +119,7 @@ async def test_run_tools_respects_max_parallel_and_timeout(): request = PromptMessageExtended(role="assistant", content=[], tool_calls=tool_calls) result_message = await agent.run_tools(request) - assert result_message.tool_results + assert result_message.tool_results is not None fast_result = result_message.tool_results["1"] slow_result = result_message.tool_results["2"] @@ -102,6 +127,9 @@ async def test_run_tools_respects_max_parallel_and_timeout(): assert not fast_result.isError # Skipped due to max_parallel cap. assert slow_result.isError + assert slow_result.content is not None + assert slow_result.content[0].type == "text" + assert isinstance(slow_result.content[0], TextContent) assert "Skipped" in slow_result.content[0].text # Now ensure timeout path yields an error result when a single slow call runs. @@ -111,9 +139,14 @@ async def test_run_tools_respects_max_parallel_and_timeout(): tool_calls={"3": CallToolRequest(params=CallToolRequestParams(name="agent__slow", arguments={"text": "hi"}))}, ) single_result = await agent.run_tools(request_single) + assert single_result.tool_results is not None err_res = single_result.tool_results["3"] assert err_res.isError - assert any("Tool execution failed" in (block.text or "") for block in err_res.content) + assert err_res.content is not None + assert any( + isinstance(block, TextContent) and "Tool execution failed" in (block.text or "") + for block in err_res.content + ) @pytest.mark.asyncio @@ -125,7 +158,8 @@ async def test_invoke_child_appends_error_channel(): call_result = await agent._invoke_child_agent(child, {"text": "hi"}) assert call_result.isError - texts = [block.text for block in call_result.content if hasattr(block, "text")] + assert call_result.content is not None + texts = [block.text for block in call_result.content if isinstance(block, TextContent)] assert "err-block" in texts @@ -144,7 +178,12 @@ async def test_nested_agents_as_tools_preserves_instance_labels(): request = PromptMessageExtended(role="assistant", content=[], tool_calls=tool_calls) result_message = await parent.run_tools(request) + assert result_message.tool_results is not None result = result_message.tool_results["1"] assert not result.isError # Reply should include the instance-suffixed nested agent name. - assert any("nested[1]-reply" in (block.text or "") for block in result.content) + assert result.content is not None + assert any( + isinstance(block, TextContent) and "nested[1]-reply" in (block.text or "") + for block in result.content + ) diff --git a/tests/unit/fast_agent/agents/workflow/test_router_unit.py b/tests/unit/fast_agent/agents/workflow/test_router_unit.py index a480190a8..ef9046cb1 100644 --- a/tests/unit/fast_agent/agents/workflow/test_router_unit.py +++ b/tests/unit/fast_agent/agents/workflow/test_router_unit.py @@ -79,7 +79,7 @@ async def test_single_agent_shortcircuit(): await router.initialize() # Test routing directly returns the single agent without LLM call - response, _ = await router._route_request("some request") + response, _ = await router._route_request(Prompt.user("some request")) # Verify result assert response diff --git a/tests/unit/fast_agent/commands/test_config_env_var.py b/tests/unit/fast_agent/commands/test_config_env_var.py index 7681de3bf..6046a96ab 100644 --- a/tests/unit/fast_agent/commands/test_config_env_var.py +++ b/tests/unit/fast_agent/commands/test_config_env_var.py @@ -47,7 +47,7 @@ def test_resolve_simple_env_var(temp_config_files): with patch.dict(os.environ, {"TEST_API_KEY": "actual_key_from_env"}): settings = get_settings(str(config_file)) - assert settings.api_key == "actual_key_from_env" + assert getattr(settings, "api_key") == "actual_key_from_env" def test_resolve_env_var_with_default_when_set(temp_config_files): @@ -57,7 +57,7 @@ def test_resolve_env_var_with_default_when_set(temp_config_files): with patch.dict(os.environ, {"SERVICE_URL": "http://env.url"}): settings = get_settings(str(config_file)) - assert settings.service_url == "http://env.url" + assert getattr(settings, "service_url") == "http://env.url" def test_resolve_env_var_with_default_when_not_set(temp_config_files): @@ -67,7 +67,7 @@ def test_resolve_env_var_with_default_when_not_set(temp_config_files): with patch.dict(os.environ, {}, clear=True): settings = get_settings(str(config_file)) - assert settings.service_url == "http://default.url" + assert getattr(settings, "service_url") == "http://default.url" def test_resolve_env_var_no_default_not_set(temp_config_files): @@ -77,7 +77,7 @@ def test_resolve_env_var_no_default_not_set(temp_config_files): with patch.dict(os.environ, {}, clear=True): settings = get_settings(str(config_file)) - assert settings.another_key == "${UNSET_KEY_NO_DEFAULT}" + assert getattr(settings, "another_key") == "${UNSET_KEY_NO_DEFAULT}" def test_nested_env_var_resolution(temp_config_files): @@ -93,9 +93,11 @@ def test_nested_env_var_resolution(temp_config_files): with patch.dict(os.environ, {"NESTED_ENV_VAR": "nested_from_env"}): settings = get_settings(str(config_file)) - assert settings.parent["child_plain"] == "value" - assert settings.parent["child_env"] == "nested_from_env" - assert settings.parent["child_env_default"] == "default_child_val" + parent = getattr(settings, "parent") + assert isinstance(parent, dict) + assert parent["child_plain"] == "value" + assert parent["child_env"] == "nested_from_env" + assert parent["child_env_default"] == "default_child_val" def test_env_var_in_list(temp_config_files): @@ -110,10 +112,11 @@ def test_env_var_in_list(temp_config_files): write_config(config_file, config_data) with patch.dict(os.environ, {"LIST_ITEM_ENV": "list_item_from_env"}): settings = get_settings(str(config_file)) - assert isinstance(settings.items, list) - assert settings.items[0] == "item1" - assert settings.items[1] == "list_item_from_env" - assert settings.items[2] == "default_list_item" + items = getattr(settings, "items") + assert isinstance(items, list) + assert items[0] == "item1" + assert items[1] == "list_item_from_env" + assert items[2] == "default_list_item" def test_mixed_config_and_secrets_with_env_vars(temp_config_files): @@ -134,10 +137,10 @@ def test_mixed_config_and_secrets_with_env_vars(temp_config_files): {"CONFIG_VAR": "env_config_val", "SECRET_ENV_KEY": "actual_secret"}, ): settings = get_settings(str(config_file)) - assert settings.general_setting == "from_config_file" - assert settings.config_env == "env_config_val" - assert settings.secret_key == "actual_secret" - assert settings.db_password == "default_db_pass" + assert getattr(settings, "general_setting") == "from_config_file" + assert getattr(settings, "config_env") == "env_config_val" + assert getattr(settings, "secret_key") == "actual_secret" + assert getattr(settings, "db_password") == "default_db_pass" def test_env_var_in_mcp_server_settings(temp_config_files): diff --git a/tests/unit/fast_agent/commands/test_url_parser.py b/tests/unit/fast_agent/commands/test_url_parser.py index 9a07ece00..9aa2fbc62 100644 --- a/tests/unit/fast_agent/commands/test_url_parser.py +++ b/tests/unit/fast_agent/commands/test_url_parser.py @@ -2,6 +2,8 @@ Unit tests for the URL parser utility functions. """ +from typing import Literal + import pytest from fast_agent.cli.commands.url_parser import ( @@ -115,7 +117,7 @@ def test_parse_server_urls_with_auth(self): def test_generate_server_configs(self): """Test generating server configurations from parsed URLs.""" - parsed_urls = [ + parsed_urls: list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]] = [ ("example_com", "http", "http://example.com/mcp", None), ("api_test_com", "sse", "https://api.test.com/sse", None), ] @@ -135,7 +137,7 @@ def test_generate_server_configs(self): def test_generate_server_configs_with_auth(self): """Test generating server configurations with auth headers.""" auth_headers = {"Authorization": "Bearer test_token_123"} - parsed_urls = [ + parsed_urls: list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]] = [ ("example_com", "http", "http://example.com/mcp", auth_headers), ("api_test_com", "sse", "https://api.test.com/sse", auth_headers), ] @@ -152,7 +154,7 @@ def test_generate_server_configs_with_auth(self): def test_generate_server_configs_with_name_collisions(self): """Test handling of server name collisions.""" # Create a list of parsed URLs with the same server name - parsed_urls = [ + parsed_urls: list[tuple[str, Literal["http", "sse"], str, dict[str, str] | None]] = [ ( "evalstate", "sse", diff --git a/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py b/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py index 4da088284..8990042cd 100644 --- a/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py +++ b/tests/unit/fast_agent/commands/test_url_parser_hf_auth.py @@ -134,6 +134,7 @@ def test_multiple_urls_mixed_hf_and_non_hf(self): # First URL - HF with token server_name, transport_type, url, headers = result[0] assert server_name == "hf_co" + assert headers is not None assert headers["Authorization"] == "Bearer hf_test_token" # Second URL - non-HF, no token @@ -144,6 +145,7 @@ def test_multiple_urls_mixed_hf_and_non_hf(self): # Third URL - HF with token server_name, transport_type, url, headers = result[2] assert server_name == "huggingface_co" + assert headers is not None assert headers["Authorization"] == "Bearer hf_test_token" finally: _restore_hf_token(original) diff --git a/tests/unit/fast_agent/core/test_instruction_refresh.py b/tests/unit/fast_agent/core/test_instruction_refresh.py index f7385a329..84dd354b5 100644 --- a/tests/unit/fast_agent/core/test_instruction_refresh.py +++ b/tests/unit/fast_agent/core/test_instruction_refresh.py @@ -1,6 +1,7 @@ """Tests for instruction building and refresh utilities.""" import asyncio +from typing import TYPE_CHECKING, cast from fast_agent.core.instruction_refresh import ( McpInstructionCapable, @@ -9,6 +10,9 @@ rebuild_agent_instruction, ) +if TYPE_CHECKING: + from fast_agent.mcp.mcp_aggregator import MCPAggregator + class StubAggregator: """Stub aggregator that returns predefined server instructions.""" @@ -135,7 +139,7 @@ def test_build_instruction_with_context() -> None: def test_build_instruction_with_aggregator() -> None: template = "{{serverInstructions}}" aggregator = StubAggregator({"my-server": ("Be helpful", ["do_thing"])}) - result = asyncio.run(build_instruction(template, aggregator=aggregator)) + result = asyncio.run(build_instruction(template, aggregator=cast("MCPAggregator", aggregator))) assert "my-server" in result assert "Be helpful" in result diff --git a/tests/unit/fast_agent/core/test_prompt.py b/tests/unit/fast_agent/core/test_prompt.py index 27b47a36f..efb4378cd 100644 --- a/tests/unit/fast_agent/core/test_prompt.py +++ b/tests/unit/fast_agent/core/test_prompt.py @@ -7,12 +7,35 @@ import tempfile from pathlib import Path -from mcp.types import EmbeddedResource, ImageContent, PromptMessage, TextContent +from mcp.types import ( + EmbeddedResource, + ImageContent, + PromptMessage, + TextContent, + TextResourceContents, +) from fast_agent.core.prompt import Prompt from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +def _text_block(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + +def _embedded_text(block: object) -> str: + assert isinstance(block, EmbeddedResource) + resource = block.resource + assert isinstance(resource, TextResourceContents) + return resource.text + + +def _resource_text(resource: object) -> str: + assert isinstance(resource, TextResourceContents) + return resource.text + + def test_user_method(): """Test the Prompt.user method.""" # Test with simple text @@ -30,8 +53,8 @@ def test_user_method(): assert isinstance(message, PromptMessageExtended) assert message.role == "user" assert len(message.content) == 2 - assert message.content[0].text == "Hello," - assert message.content[1].text == "How are you?" + assert _text_block(message.content[0]).text == "Hello," + assert _text_block(message.content[1]).text == "How are you?" # Test with PromptMessage prompt_message = PromptMessage( @@ -42,7 +65,7 @@ def test_user_method(): assert isinstance(message, PromptMessageExtended) assert message.role == "user" # Role should be changed to user assert len(message.content) == 1 - assert message.content[0].text == "I'm a PromptMessage" + assert _text_block(message.content[0]).text == "I'm a PromptMessage" # Test with PromptMessageExtended multipart = Prompt.assistant("I'm a multipart message") @@ -51,7 +74,7 @@ def test_user_method(): assert isinstance(message, PromptMessageExtended) assert message.role == "user" # Role should be changed to user assert len(message.content) == 1 - assert message.content[0].text == "I'm a multipart message" + assert _text_block(message.content[0]).text == "I'm a multipart message" def test_assistant_method(): @@ -98,16 +121,15 @@ def test_with_file_paths(): assert message.role == "user" assert len(message.content) == 2 - assert message.content[0].text == "Check this file:" - assert isinstance(message.content[1], EmbeddedResource) - assert message.content[1].resource.text == "Hello, world!" + assert _text_block(message.content[0]).text == "Check this file:" + assert _embedded_text(message.content[1]) == "Hello, world!" # Test with image file message = Prompt.assistant("Here's the image:", Path(image_path)) assert message.role == "assistant" assert len(message.content) == 2 - assert message.content[0].text == "Here's the image:" + assert _text_block(message.content[0]).text == "Here's the image:" assert isinstance(message.content[1], ImageContent) # Decode the base64 data @@ -137,14 +159,14 @@ def test_with_file_paths(): assert len(message.content) == 2 # Using dictionary comparison because the objects might not be identity-equal assert message.content[1].type == embedded.type - assert message.content[1].resource.text == embedded.resource.text + assert _embedded_text(message.content[1]) == _resource_text(embedded.resource) # Test with ReadResourceResult resource_result = ReadResourceResult(contents=[text_resource]) message = Prompt.user("Resource result:", resource_result) assert message.role == "user" assert len(message.content) > 1 # Should have text + resource - assert message.content[0].text == "Resource result:" + assert _text_block(message.content[0]).text == "Resource result:" assert isinstance(message.content[1], EmbeddedResource) # Test with direct TextContent @@ -167,9 +189,9 @@ def test_with_file_paths(): message = Prompt.user("Text followed by:", text_content, "And an image:", image_content) assert message.role == "user" assert len(message.content) == 4 - assert message.content[0].text == "Text followed by:" + assert _text_block(message.content[0]).text == "Text followed by:" assert message.content[1] == text_content - assert message.content[2].text == "And an image:" + assert _text_block(message.content[2]).text == "And an image:" assert message.content[3] == image_content finally: @@ -201,6 +223,7 @@ def test_conversation_method(): assert len(mixed_conversation) == 3 assert mixed_conversation[0].role == "user" assert mixed_conversation[1].role == "assistant" + assert isinstance(mixed_conversation[1].content, TextContent) assert mixed_conversation[1].content.text == "Direct dict!" assert mixed_conversation[2].role == "user" diff --git a/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py b/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py index efdfe3242..da53ffd9a 100644 --- a/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py +++ b/tests/unit/fast_agent/llm/provider/anthropic/test_anthropic_cache_control.py @@ -1,3 +1,5 @@ + +from anthropic.types import MessageParam from mcp.types import TextContent from fast_agent.llm.provider.anthropic.cache_planner import AnthropicCachePlanner @@ -12,13 +14,16 @@ def make_message(text: str, *, is_template: bool = False) -> PromptMessageExtend ) -def count_cache_controls(messages: list[dict]) -> int: - return sum( - 1 - for msg in messages - for block in msg.get("content", []) - if isinstance(block, dict) and block.get("cache_control") - ) +def count_cache_controls(messages: list[MessageParam]) -> int: + total = 0 + for msg in messages: + content = msg.get("content", []) + if isinstance(content, str): + continue + for block in content: + if isinstance(block, dict) and block.get("cache_control"): + total += 1 + return total def test_template_cache_respects_budget(): @@ -35,8 +40,10 @@ def test_template_cache_respects_budget(): for idx in plan_indices: AnthropicLLM._apply_cache_control_to_message(provider_msgs[idx]) - assert "cache_control" in provider_msgs[0]["content"][-1] - assert "cache_control" in provider_msgs[1]["content"][-1] + first_blocks = list(provider_msgs[0]["content"]) + second_blocks = list(provider_msgs[1]["content"]) + assert "cache_control" in first_blocks[-1] + assert "cache_control" in second_blocks[-1] def test_conversation_cache_respects_four_block_limit(): diff --git a/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py b/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py index 1ffade622..d3b7b2068 100644 --- a/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py +++ b/tests/unit/fast_agent/llm/provider/anthropic/test_tool_id_sanitization.py @@ -20,7 +20,9 @@ def test_sanitizes_tool_use_ids_for_assistant_calls(): converted: MessageParam = AnthropicConverter.convert_to_anthropic(msg) assert converted["role"] == "assistant" - assert converted["content"][0]["id"] == expected + content_blocks = list(converted["content"]) + assert isinstance(content_blocks[0], dict) + assert content_blocks[0]["id"] == expected def test_sanitizes_tool_use_ids_for_tool_results(): @@ -33,4 +35,6 @@ def test_sanitizes_tool_use_ids_for_tool_results(): converted: MessageParam = AnthropicConverter.convert_to_anthropic(msg) assert converted["role"] == "user" - assert converted["content"][0]["tool_use_id"] == expected + content_blocks = list(converted["content"]) + assert isinstance(content_blocks[0], dict) + assert content_blocks[0]["tool_use_id"] == expected diff --git a/tests/unit/fast_agent/llm/providers/test_google_converter.py b/tests/unit/fast_agent/llm/providers/test_google_converter.py index 731efe2d5..7a0b17dff 100644 --- a/tests/unit/fast_agent/llm/providers/test_google_converter.py +++ b/tests/unit/fast_agent/llm/providers/test_google_converter.py @@ -36,9 +36,10 @@ def test_convert_function_results_to_google_text_only(): content = contents[0] assert isinstance(content, types.Content) assert content.role == "tool" - assert content.parts + parts = content.parts + assert parts is not None # First part should be a function response named 'weather' - fn_resp = content.parts[0].function_response + fn_resp = parts[0].function_response assert fn_resp is not None assert fn_resp.name == "weather" assert isinstance(fn_resp.response, dict) @@ -78,7 +79,7 @@ def test_convert_video_resource(): resource = EmbeddedResource( type="resource", resource=BlobResourceContents( - uri="file:///path/to/video.mp4", + uri=AnyUrl("file:///path/to/video.mp4"), mimeType="video/mp4", blob=encoded_video ) @@ -99,8 +100,10 @@ def test_convert_video_resource(): content = contents[0] assert isinstance(content, types.Content) - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Check if it's an inline data part assert part.inline_data is not None @@ -117,7 +120,7 @@ def test_convert_mixed_content_video_text(): video_resource = EmbeddedResource( type="resource", resource=BlobResourceContents( - uri="file:///video.mp4", + uri=AnyUrl("file:///video.mp4"), mimeType="video/mp4", blob=encoded_video ) @@ -138,14 +141,16 @@ def test_convert_mixed_content_video_text(): # Verify assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 2 + parts = content.parts + assert parts is not None + assert len(parts) == 2 # First part should be video - assert content.parts[0].inline_data is not None - assert content.parts[0].inline_data.mime_type == "video/mp4" + assert parts[0].inline_data is not None + assert parts[0].inline_data.mime_type == "video/mp4" # Second part should be text - assert content.parts[1].text == "Describe this video" + assert parts[1].text == "Describe this video" def test_convert_youtube_url_video(): @@ -172,8 +177,10 @@ def test_convert_youtube_url_video(): # Verify assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Should use file_data for YouTube URLs assert part.file_data is not None @@ -193,8 +200,10 @@ def test_convert_resource_link_video(): assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Should use file_data for video ResourceLink assert part.file_data is not None @@ -214,8 +223,10 @@ def test_convert_resource_link_image(): assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Should use file_data for image ResourceLink assert part.file_data is not None @@ -235,8 +246,10 @@ def test_convert_resource_link_audio(): assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Should use file_data for audio ResourceLink assert part.file_data is not None @@ -260,8 +273,10 @@ def test_convert_resource_link_text_fallback(): assert len(contents) == 1 content = contents[0] - assert len(content.parts) == 1 - part = content.parts[0] + parts = content.parts + assert parts is not None + assert len(parts) == 1 + part = parts[0] # Should use text for non-media ResourceLink assert part.text is not None @@ -286,11 +301,14 @@ def test_convert_resource_link_in_tool_result(): assert content.role == "tool" # Should have function response part and media part - assert len(content.parts) >= 1 + parts = content.parts + assert parts is not None + assert len(parts) >= 1 # Check for the media part (video) - media_parts = [p for p in content.parts if p.file_data is not None] + media_parts = [p for p in parts if p.file_data is not None] assert len(media_parts) == 1 + assert media_parts[0].file_data is not None assert media_parts[0].file_data.file_uri == "https://storage.example.com/output.mp4" assert media_parts[0].file_data.mime_type == "video/mp4" @@ -315,7 +333,11 @@ def test_convert_resource_link_text_in_tool_result(): assert content.role == "tool" # Should have function response part with text content - fn_resp = content.parts[0].function_response + parts = content.parts + assert parts is not None + fn_resp = parts[0].function_response assert fn_resp is not None - assert "text_content" in fn_resp.response - assert "config_file" in fn_resp.response["text_content"] + response = fn_resp.response + assert isinstance(response, dict) + assert "text_content" in response + assert "config_file" in response["text_content"] diff --git a/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py b/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py index f5e5f58d1..dfc7cc274 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_anthropic_caching.py @@ -5,7 +5,10 @@ to verify cache_control markers are applied correctly based on cache_mode settings. """ +from typing import Literal + import pytest +from anthropic.types import MessageParam from mcp.types import CallToolResult, TextContent from fast_agent.config import AnthropicSettings, Settings @@ -20,7 +23,9 @@ class TestAnthropicCaching: """Test cases for Anthropic caching functionality.""" - def _create_context_with_cache_mode(self, cache_mode: str) -> Context: + def _create_context_with_cache_mode( + self, cache_mode: Literal["off", "prompt", "auto"] + ) -> Context: """Create a context with specified cache mode.""" ctx = Context() ctx.config = Settings() @@ -29,15 +34,20 @@ def _create_context_with_cache_mode(self, cache_mode: str) -> Context: ) return ctx - def _create_llm(self, cache_mode: str = "off") -> AnthropicLLM: + def _create_llm( + self, cache_mode: Literal["off", "prompt", "auto"] = "off" + ) -> AnthropicLLM: """Create an AnthropicLLM instance with specified cache mode.""" ctx = self._create_context_with_cache_mode(cache_mode) llm = AnthropicLLM(context=ctx) return llm def _apply_cache_plan( - self, messages: list[PromptMessageExtended], cache_mode: str, system_blocks: int = 0 - ) -> list[dict]: + self, + messages: list[PromptMessageExtended], + cache_mode: Literal["off", "prompt", "auto"], + system_blocks: int = 0, + ) -> list[MessageParam]: planner = AnthropicCachePlanner() plan = planner.plan_indices(messages, cache_mode=cache_mode, system_cache_blocks=system_blocks) converted = [AnthropicConverter.convert_to_anthropic(m) for m in messages] @@ -281,7 +291,8 @@ def test_cache_control_on_last_content_block(self): converted = self._apply_cache_plan(template_msgs, cache_mode="prompt") # Cache control should be on the last block - content_blocks = converted[0]["content"] + content = converted[0].get("content", []) + content_blocks = [] if isinstance(content, str) else list(content) assert len(content_blocks) == 2 # First block should NOT have cache_control @@ -291,8 +302,7 @@ def test_cache_control_on_last_content_block(self): # At least one block should have cache_control found_cache_control = any( - isinstance(block, dict) and "cache_control" in block - for block in content_blocks + isinstance(block, dict) and "cache_control" in block for block in content_blocks ) assert found_cache_control, "Template should have cache_control" diff --git a/tests/unit/fast_agent/llm/providers/test_llm_google_vertex.py b/tests/unit/fast_agent/llm/providers/test_llm_google_vertex.py index f434ab5f4..3f8da1a7d 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_google_vertex.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_google_vertex.py @@ -14,7 +14,11 @@ def _build_llm(config: Settings) -> GoogleNativeLLM: def test_vertex_cfg_accepts_model_object_and_resolves_preview_names() -> None: """Vertex config may arrive as a pydantic model with a custom attr object.""" google_settings = GoogleSettings() - google_settings.vertex_ai = types.SimpleNamespace(enabled=True, project_id="proj", location="loc") + setattr( + google_settings, + "vertex_ai", + types.SimpleNamespace(enabled=True, project_id="proj", location="loc"), + ) config = Settings(google=google_settings) llm = _build_llm(config) diff --git a/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py b/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py index a7ac222e1..151a450dc 100644 --- a/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py +++ b/tests/unit/fast_agent/llm/providers/test_llm_tensorzero_unit.py @@ -1,12 +1,15 @@ +from typing import TYPE_CHECKING, cast from unittest.mock import MagicMock, patch import pytest -from openai.types.chat import ChatCompletionMessageParam, ChatCompletionSystemMessageParam from fast_agent.agents import McpAgent from fast_agent.llm.provider.openai.llm_tensorzero_openai import TensorZeroOpenAILLM from fast_agent.llm.request_params import RequestParams +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionMessageParam, ChatCompletionSystemMessageParam + # --- Fixtures --- @@ -87,8 +90,9 @@ def test_prepare_api_request_with_template_vars(mock_super_prepare, t0_llm): @patch("fast_agent.llm.provider.openai.llm_openai.OpenAILLM._prepare_api_request") def test_prepare_api_request_merges_metadata(mock_super_prepare, t0_llm): """Tests merging of tensorzero_arguments from metadata.""" - initial_system_message = ChatCompletionSystemMessageParam( - role="system", content=[{"var1": "original"}] + initial_system_message = cast( + "ChatCompletionSystemMessageParam", + {"role": "system", "content": [{"var1": "original"}]}, ) messages: list[ChatCompletionMessageParam] = [initial_system_message] mock_super_prepare.return_value = {"model": "test_chat", "messages": messages} @@ -116,8 +120,9 @@ def test_prepare_api_request_adds_episode_id(mock_super_prepare, t0_llm): @patch("fast_agent.llm.provider.openai.llm_openai.OpenAILLM._prepare_api_request") def test_prepare_api_request_all_features(mock_super_prepare, t0_llm): """Tests all features working together.""" - initial_system_message = ChatCompletionSystemMessageParam( - role="system", content="Original prompt" + initial_system_message = cast( + "ChatCompletionSystemMessageParam", + {"role": "system", "content": "Original prompt"}, ) messages: list[ChatCompletionMessageParam] = [initial_system_message] mock_super_prepare.return_value = { diff --git a/tests/unit/fast_agent/llm/providers/test_multipart_converter_anthropic.py b/tests/unit/fast_agent/llm/providers/test_multipart_converter_anthropic.py index c405bb8be..f51ad3836 100644 --- a/tests/unit/fast_agent/llm/providers/test_multipart_converter_anthropic.py +++ b/tests/unit/fast_agent/llm/providers/test_multipart_converter_anthropic.py @@ -1,6 +1,8 @@ import base64 import json import unittest +from collections.abc import Iterable, Mapping +from typing import cast from mcp.types import ( BlobResourceContents, @@ -25,9 +27,36 @@ PDF_BASE64 = base64.b64encode(b"fake_pdf_data").decode("utf-8") +def content_blocks(message: Mapping[str, object]) -> list[dict[str, object]]: + content = message.get("content", []) + if isinstance(content, str): + return [] + if not isinstance(content, Iterable): + return [] + filtered = [block for block in content if isinstance(block, dict)] + return cast("list[dict[str, object]]", filtered) + + +def block_source(block: dict[str, object]) -> dict[str, object]: + source = block.get("source") + assert isinstance(source, dict) + return cast("dict[str, object]", source) + + +def block_content(block: dict[str, object]) -> list[dict[str, object]]: + content = block.get("content", []) + if isinstance(content, str): + return [] + if not isinstance(content, Iterable): + return [] + filtered = [item for item in content if isinstance(item, dict)] + return cast("list[dict[str, object]]", filtered) + + + def create_pdf_resource(pdf_base64) -> EmbeddedResource: pdf_resource: BlobResourceContents = BlobResourceContents( - uri="test://example.com/document.pdf", + uri=AnyUrl("test://example.com/document.pdf"), mimeType="application/pdf", blob=pdf_base64, ) @@ -53,9 +82,9 @@ def test_text_content_conversion(self): # Assertions - using dictionary access, not attribute access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_image_content_conversion(self): """Test conversion of ImageContent to Anthropic image block.""" @@ -70,17 +99,17 @@ def test_image_content_conversion(self): # Assertions - using dictionary access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "image") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "base64") - self.assertEqual(anthropic_msg["content"][0]["source"]["media_type"], "image/jpeg") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], self.sample_image_base64) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "image") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "base64") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["media_type"], "image/jpeg") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], self.sample_image_base64) def test_embedded_resource_text_conversion(self): """Test conversion of text-based EmbeddedResource to Anthropic document block.""" # Create a text resource text_resource = TextResourceContents( - uri="test://example.com/document.txt", + uri=AnyUrl("test://example.com/document.txt"), mimeType="text/plain", text=self.sample_text, ) @@ -92,12 +121,12 @@ def test_embedded_resource_text_conversion(self): # Assertions - using dictionary access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "document") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["title"], "document.txt") - self.assertEqual(anthropic_msg["content"][0]["source"]["media_type"], "text/plain") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "document") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["title"], "document.txt") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["media_type"], "text/plain") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], self.sample_text) def test_embedded_resource_pdf_conversion(self): """Test conversion of PDF EmbeddedResource to Anthropic document block.""" @@ -110,17 +139,17 @@ def test_embedded_resource_pdf_conversion(self): # Assertions - using dictionary access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "document") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "base64") - self.assertEqual(anthropic_msg["content"][0]["source"]["media_type"], "application/pdf") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], PDF_BASE64) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "document") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "base64") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["media_type"], "application/pdf") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], PDF_BASE64) def test_embedded_resource_image_url_conversion(self): """Test conversion of image URL in EmbeddedResource to Anthropic image block.""" # Create an image resource with URL image_resource = BlobResourceContents( - uri="https://example.com/image.jpg", + uri=AnyUrl("https://example.com/image.jpg"), mimeType="image/jpeg", blob=self.sample_image_base64, # This should be ignored for URL ) @@ -132,11 +161,11 @@ def test_embedded_resource_image_url_conversion(self): # Assertions - using dictionary access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "image") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "url") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "image") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "url") self.assertEqual( - anthropic_msg["content"][0]["source"]["url"], + block_source(content_blocks(anthropic_msg)[0])["url"], "https://example.com/image.jpg", ) @@ -154,9 +183,9 @@ def test_assistant_role_restrictions(self): # Assertions - only text should remain self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_multiple_content_blocks(self): """Test conversion of messages with multiple content blocks.""" @@ -176,12 +205,12 @@ def test_multiple_content_blocks(self): # Assertions - using dictionary access self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 3) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], "First text") - self.assertEqual(anthropic_msg["content"][1]["type"], "image") - self.assertEqual(anthropic_msg["content"][2]["type"], "text") - self.assertEqual(anthropic_msg["content"][2]["text"], "Second text") + self.assertEqual(len(content_blocks(anthropic_msg)), 3) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], "First text") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "image") + self.assertEqual(content_blocks(anthropic_msg)[2]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[2]["text"], "Second text") def test_unsupported_mime_type_handling(self): """Test handling of unsupported MIME types.""" @@ -198,21 +227,21 @@ def test_unsupported_mime_type_handling(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Should have kept the text content and added a fallback text for the image - self.assertEqual(len(anthropic_msg["content"]), 2) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], "This is some text") - self.assertEqual(anthropic_msg["content"][1]["type"], "text") - self.assertIn( - "Image with unsupported format 'image/bmp'", - anthropic_msg["content"][1]["text"], - ) + self.assertEqual(len(content_blocks(anthropic_msg)), 2) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], "This is some text") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "text") + fallback_text = content_blocks(anthropic_msg)[1]["text"] + self.assertIsInstance(fallback_text, str) + assert isinstance(fallback_text, str) + self.assertIn("Image with unsupported format 'image/bmp'", fallback_text) def test_svg_resource_conversion(self): """Test handling of SVG resources - should convert to code block.""" # Create an embedded SVG resource svg_content = '' svg_resource = TextResourceContents( - uri="test://example.com/image.svg", + uri=AnyUrl("test://example.com/image.svg"), mimeType="image/svg+xml", text=svg_content, ) @@ -223,10 +252,13 @@ def test_svg_resource_conversion(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Should be converted to a text block with the SVG code - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertIn("```xml", anthropic_msg["content"][0]["text"]) - self.assertIn(svg_content, anthropic_msg["content"][0]["text"]) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + svg_text = content_blocks(anthropic_msg)[0]["text"] + self.assertIsInstance(svg_text, str) + assert isinstance(svg_text, str) + self.assertIn("```xml", svg_text) + self.assertIn(svg_content, svg_text) def test_empty_content_list(self): """Test conversion with empty content list.""" @@ -237,13 +269,13 @@ def test_empty_content_list(self): # Should have empty content list self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 0) + self.assertEqual(len(content_blocks(anthropic_msg)), 0) def test_embedded_resource_pdf_url_conversion(self): """Test conversion of PDF URL in EmbeddedResource to Anthropic document block.""" # Create a PDF resource with URL pdf_resource = BlobResourceContents( - uri="https://example.com/document.pdf", + uri=AnyUrl("https://example.com/document.pdf"), mimeType="application/pdf", blob=base64.b64encode(b"fake_pdf_data").decode("utf-8"), ) @@ -254,10 +286,10 @@ def test_embedded_resource_pdf_url_conversion(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Assertions - using dictionary access - self.assertEqual(anthropic_msg["content"][0]["type"], "document") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "url") + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "document") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "url") self.assertEqual( - anthropic_msg["content"][0]["source"]["url"], + block_source(content_blocks(anthropic_msg)[0])["url"], "https://example.com/document.pdf", ) @@ -284,14 +316,16 @@ def test_mixed_content_with_unsupported_formats(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Should have kept the text, created fallback for unsupported, and kept supported image - self.assertEqual(len(anthropic_msg["content"]), 3) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 3) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) self.assertEqual( - anthropic_msg["content"][1]["type"], "text" + content_blocks(anthropic_msg)[1]["type"], "text" ) # Fallback text for unsupported - self.assertEqual(anthropic_msg["content"][2]["type"], "image") # Supported image kept - self.assertEqual(anthropic_msg["content"][2]["source"]["media_type"], "image/jpeg") + self.assertEqual(content_blocks(anthropic_msg)[2]["type"], "image") # Supported image kept + self.assertEqual( + block_source(content_blocks(anthropic_msg)[2])["media_type"], "image/jpeg" + ) def test_multipart_with_tool_results_and_content(self): """Test conversion of PromptMessageExtended with both tool_results and content.""" @@ -313,17 +347,17 @@ def test_multipart_with_tool_results_and_content(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 2) + self.assertEqual(len(content_blocks(anthropic_msg)), 2) # First block should be tool_result (must come first per Anthropic API) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], "tool_id_1") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["text"], "Tool execution result") + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], "tool_id_1") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["text"], "Tool execution result") # Second block should be the additional content - self.assertEqual(anthropic_msg["content"][1]["type"], "text") - self.assertEqual(anthropic_msg["content"][1]["text"], "What should I do next?") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[1]["text"], "What should I do next?") def test_code_file_as_text_document_with_filename(self): """Test handling of code files using a simple filename.""" @@ -342,9 +376,9 @@ def test_code_file_as_text_document_with_filename(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Check that title is set correctly - self.assertEqual(anthropic_msg["content"][0]["title"], "example.py") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], code_text) - self.assertEqual(anthropic_msg["content"][0]["source"]["media_type"], "text/plain") + self.assertEqual(content_blocks(anthropic_msg)[0]["title"], "example.py") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], code_text) + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["media_type"], "text/plain") def test_code_file_as_text_document_with_uri(self): """Test handling of code files using a proper URI.""" @@ -365,15 +399,15 @@ def test_code_file_as_text_document_with_uri(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Should extract just the filename from the path - self.assertEqual(anthropic_msg["content"][0]["title"], "example.py") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], code_text) + self.assertEqual(content_blocks(anthropic_msg)[0]["title"], "example.py") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], code_text) def test_unsupported_binary_resource_conversion(self): """Test handling of unsupported binary resource types.""" # Create an embedded resource with binary data binary_data = base64.b64encode(b"This is binary data").decode("utf-8") # 20 bytes of data binary_resource = BlobResourceContents( - uri="test://example.com/data.bin", + uri=AnyUrl("test://example.com/data.bin"), mimeType="application/octet-stream", blob=binary_data, ) @@ -384,11 +418,13 @@ def test_unsupported_binary_resource_conversion(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) # Should have a fallback text block - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") # Check that the content describes it as unsupported format - fallback_text = anthropic_msg["content"][0]["text"] + fallback_text = content_blocks(anthropic_msg)[0]["text"] + self.assertIsInstance(fallback_text, str) + assert isinstance(fallback_text, str) self.assertIn( "Embedded Resource test://example.com/data.bin with unsupported format application/octet-stream (28 characters)", fallback_text, @@ -419,20 +455,27 @@ def test_pdf_result_conversion(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") # Now a single tool_result block that contains both text and the document - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], self.tool_use_id) - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 2) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], self.tool_use_id) + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 2) # First inner block: text - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["text"], self.sample_text) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["text"], self.sample_text) # Second inner block: document (PDF) - self.assertEqual(anthropic_msg["content"][0]["content"][1]["type"], "document") - self.assertEqual(anthropic_msg["content"][0]["content"][1]["source"]["type"], "base64") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[1]["type"], "document") + self.assertEqual( + block_source(block_content(content_blocks(anthropic_msg)[0])[1])["type"], + "base64", + ) + self.assertEqual( + block_source(block_content(content_blocks(anthropic_msg)[0])[1])["media_type"], + "application/pdf", + ) self.assertEqual( - anthropic_msg["content"][0]["content"][1]["source"]["media_type"], "application/pdf" + block_source(block_content(content_blocks(anthropic_msg)[0])[1])["data"], + PDF_BASE64, ) - self.assertEqual(anthropic_msg["content"][0]["content"][1]["source"]["data"], PDF_BASE64) def test_binary_only_tool_result_conversion(self): """Binary-only tool result should be a single tool_result with a document inside.""" @@ -447,10 +490,10 @@ def test_binary_only_tool_result_conversion(self): # Should have a single tool_result block with the PDF document inside self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "document") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 1) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "document") def test_create_tool_results_message(self): """Test creation of user message with multiple tool results.""" @@ -475,18 +518,20 @@ def test_create_tool_results_message(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 2) + self.assertEqual(len(content_blocks(anthropic_msg)), 2) # Check first tool result - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], tool_use_id1) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["text"], self.sample_text) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], tool_use_id1) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["text"], self.sample_text) # Check second tool result - self.assertEqual(anthropic_msg["content"][1]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][1]["tool_use_id"], tool_use_id2) - self.assertEqual(anthropic_msg["content"][1]["content"][0]["type"], "image") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[1]["tool_use_id"], tool_use_id2) + self.assertEqual( + block_content(content_blocks(anthropic_msg)[1])[0]["type"], "image" + ) def test_create_tool_results_message_with_error(self): """Test creation of tool results message with error flag.""" @@ -500,13 +545,13 @@ def test_create_tool_results_message_with_error(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], tool_use_id) - self.assertEqual(anthropic_msg["content"][0]["is_error"], True) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], tool_use_id) + self.assertEqual(content_blocks(anthropic_msg)[0]["is_error"], True) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") self.assertEqual( - anthropic_msg["content"][0]["content"][0]["text"], "Error: Something went wrong" + block_content(content_blocks(anthropic_msg)[0])[0]["text"], "Error: Something went wrong" ) def test_create_tool_results_message_with_empty_content(self): @@ -520,13 +565,13 @@ def test_create_tool_results_message_with_empty_content(self): # Should have a placeholder text block self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], tool_use_id) - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], tool_use_id) + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 1) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") self.assertEqual( - anthropic_msg["content"][0]["content"][0]["text"], "[No content in tool result]" + block_content(content_blocks(anthropic_msg)[0])[0]["text"], "[No content in tool result]" ) def test_create_tool_results_message_with_unsupported_image(self): @@ -545,14 +590,14 @@ def test_create_tool_results_message_with_unsupported_image(self): # Unsupported image should be converted to text self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertIn( - "Image with unsupported format 'image/bmp'", - anthropic_msg["content"][0]["content"][0]["text"], - ) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 1) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + unsupported_text = block_content(content_blocks(anthropic_msg)[0])[0]["text"] + self.assertIsInstance(unsupported_text, str) + assert isinstance(unsupported_text, str) + self.assertIn("Image with unsupported format 'image/bmp'", unsupported_text) def test_create_tool_results_message_with_mixed_content(self): """Test creation of tool results message with mixed text and image content.""" @@ -569,13 +614,13 @@ def test_create_tool_results_message_with_mixed_content(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], tool_use_id) - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 2) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["text"], self.sample_text) - self.assertEqual(anthropic_msg["content"][0]["content"][1]["type"], "image") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], tool_use_id) + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 2) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["text"], self.sample_text) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[1]["type"], "image") def test_create_tool_results_message_with_text_resource(self): """Test creation of tool results message with text resource (markdown).""" @@ -595,16 +640,16 @@ def test_create_tool_results_message_with_text_resource(self): # Text resources should be included in tool result as text blocks self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "tool_result") - self.assertEqual(anthropic_msg["content"][0]["tool_use_id"], tool_use_id) - self.assertEqual(len(anthropic_msg["content"][0]["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["content"][0]["text"], "markdown text") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "tool_result") + self.assertEqual(content_blocks(anthropic_msg)[0]["tool_use_id"], tool_use_id) + self.assertEqual(len(block_content(content_blocks(anthropic_msg)[0])), 1) + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["type"], "text") + self.assertEqual(block_content(content_blocks(anthropic_msg)[0])[0]["text"], "markdown text") def create_text_resource( - text: str, filename_or_uri: str, mime_type: str = None + text: str, filename_or_uri: str, mime_type: str | None = None ) -> TextResourceContents: """ Helper function to create a TextResourceContents with proper URI handling. @@ -620,7 +665,7 @@ def create_text_resource( # Normalize the URI uri = normalize_uri(filename_or_uri) - return TextResourceContents(uri=uri, mimeType=mime_type, text=text) + return TextResourceContents(uri=AnyUrl(uri), mimeType=mime_type, text=text) class TestAnthropicAssistantConverter(unittest.TestCase): @@ -641,9 +686,9 @@ def test_assistant_text_content_conversion(self): # Assertions self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_convert_prompt_message_to_anthropic(self): """Test conversion of a standard PromptMessage to Anthropic format.""" @@ -656,9 +701,9 @@ def test_convert_prompt_message_to_anthropic(self): # Assertions self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_convert_prompt_message_image_to_anthropic(self): """Test conversion of a PromptMessage with image content to Anthropic format.""" @@ -672,17 +717,17 @@ def test_convert_prompt_message_image_to_anthropic(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "image") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "base64") - self.assertEqual(anthropic_msg["content"][0]["source"]["media_type"], "image/jpeg") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], image_base64) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "image") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "base64") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["media_type"], "image/jpeg") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], image_base64) def test_convert_prompt_message_embedded_resource_to_anthropic(self): """Test conversion of a PromptMessage with embedded resource to Anthropic format.""" # Create a PromptMessage with embedded text resource text_resource = TextResourceContents( - uri="test://example.com/document.txt", + uri=AnyUrl("test://example.com/document.txt"), mimeType="text/plain", text="This is a text resource", ) @@ -694,11 +739,11 @@ def test_convert_prompt_message_embedded_resource_to_anthropic(self): # Assertions self.assertEqual(anthropic_msg["role"], "user") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "document") - self.assertEqual(anthropic_msg["content"][0]["source"]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["title"], "document.txt") - self.assertEqual(anthropic_msg["content"][0]["source"]["data"], "This is a text resource") + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "document") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["title"], "document.txt") + self.assertEqual(block_source(content_blocks(anthropic_msg)[0])["data"], "This is a text resource") def test_assistant_multiple_text_blocks(self): """Test conversion of assistant messages with multiple text blocks.""" @@ -713,11 +758,11 @@ def test_assistant_multiple_text_blocks(self): # Assertions self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 2) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], "First part of response") - self.assertEqual(anthropic_msg["content"][1]["type"], "text") - self.assertEqual(anthropic_msg["content"][1]["text"], "Second part of response") + self.assertEqual(len(content_blocks(anthropic_msg)), 2) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], "First part of response") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[1]["text"], "Second part of response") def test_assistant_thinking_blocks_deserialized_from_channel(self): """Ensure thinking channel JSON is converted to Anthropic thinking params.""" @@ -746,14 +791,14 @@ def test_assistant_thinking_blocks_deserialized_from_channel(self): anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 3) - self.assertEqual(anthropic_msg["content"][0]["type"], "thinking") - self.assertEqual(anthropic_msg["content"][0]["thinking"], "Reasoning summary.") - self.assertEqual(anthropic_msg["content"][0]["signature"], "sig123") - self.assertEqual(anthropic_msg["content"][1]["type"], "redacted_thinking") - self.assertEqual(anthropic_msg["content"][1]["data"], "opaque") - self.assertEqual(anthropic_msg["content"][2]["type"], "tool_use") - self.assertEqual(anthropic_msg["content"][2]["name"], "test_tool") + self.assertEqual(len(content_blocks(anthropic_msg)), 3) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "thinking") + self.assertEqual(content_blocks(anthropic_msg)[0]["thinking"], "Reasoning summary.") + self.assertEqual(content_blocks(anthropic_msg)[0]["signature"], "sig123") + self.assertEqual(content_blocks(anthropic_msg)[1]["type"], "redacted_thinking") + self.assertEqual(content_blocks(anthropic_msg)[1]["data"], "opaque") + self.assertEqual(content_blocks(anthropic_msg)[2]["type"], "tool_use") + self.assertEqual(content_blocks(anthropic_msg)[2]["name"], "test_tool") def test_assistant_non_text_content_stripped(self): """Test that non-text content is stripped from assistant messages.""" @@ -772,9 +817,9 @@ def test_assistant_non_text_content_stripped(self): # Only text should remain, image should be filtered out self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_assistant_embedded_resource_stripped(self): """Test that embedded resources are stripped from assistant messages.""" @@ -782,7 +827,7 @@ def test_assistant_embedded_resource_stripped(self): text_content = TextContent(type="text", text=self.sample_text) resource_content = TextResourceContents( - uri="test://example.com/document.txt", + uri=AnyUrl("test://example.com/document.txt"), mimeType="text/plain", text="Some document content", ) @@ -797,9 +842,9 @@ def test_assistant_embedded_resource_stripped(self): # Only text should remain, resource should be filtered out self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 1) - self.assertEqual(anthropic_msg["content"][0]["type"], "text") - self.assertEqual(anthropic_msg["content"][0]["text"], self.sample_text) + self.assertEqual(len(content_blocks(anthropic_msg)), 1) + self.assertEqual(content_blocks(anthropic_msg)[0]["type"], "text") + self.assertEqual(content_blocks(anthropic_msg)[0]["text"], self.sample_text) def test_assistant_empty_content(self): """Test conversion with empty content from assistant.""" @@ -810,4 +855,4 @@ def test_assistant_empty_content(self): # Should have empty content list self.assertEqual(anthropic_msg["role"], "assistant") - self.assertEqual(len(anthropic_msg["content"]), 0) + self.assertEqual(len(content_blocks(anthropic_msg)), 0) diff --git a/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py b/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py index a9fcbad85..ae3374838 100644 --- a/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py +++ b/tests/unit/fast_agent/llm/providers/test_multipart_converter_google.py @@ -37,7 +37,13 @@ def test_tool_result_conversion(self): ) assert 1 == len(converted) assert "tool" == converted[0].role - assert self.sample_text == converted[0].parts[0].function_response.response["text_content"] + parts = converted[0].parts + assert parts is not None + fn_resp = parts[0].function_response + assert fn_resp is not None + response = fn_resp.response + assert isinstance(response, dict) + assert self.sample_text == response["text_content"] def test_multiple_tool_results_with_mixed_content(self): """Test conversion of multiple tool results with different content types.""" @@ -66,8 +72,16 @@ def test_multiple_tool_results_with_mixed_content(self): # Assertions assert 2 == len(converted) - assert 1 == len(converted[0].parts) # Text Only - assert 2 == len(converted[1].parts[0].function_response.response) # Text and Image + first_parts = converted[0].parts + assert first_parts is not None + assert 1 == len(first_parts) # Text Only + second_parts = converted[1].parts + assert second_parts is not None + second_fn_resp = second_parts[0].function_response + assert second_fn_resp is not None + second_response = second_fn_resp.response + assert isinstance(second_response, dict) + assert 2 == len(second_response) # Text and Image # assert self.sample_text == converted[0].parts[0].function_response.response["text"][0] diff --git a/tests/unit/fast_agent/llm/providers/test_multipart_converter_openai.py b/tests/unit/fast_agent/llm/providers/test_multipart_converter_openai.py index 49275c2cc..b9859fa1e 100644 --- a/tests/unit/fast_agent/llm/providers/test_multipart_converter_openai.py +++ b/tests/unit/fast_agent/llm/providers/test_multipart_converter_openai.py @@ -1,5 +1,7 @@ import base64 import unittest +from collections.abc import Iterable, Mapping +from typing import TYPE_CHECKING, cast from mcp.types import ( BlobResourceContents, @@ -20,6 +22,40 @@ from fast_agent.llm.provider_types import Provider from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionToolMessageParam, ChatCompletionUserMessageParam + + +def content_parts(message: Mapping[str, object]) -> list[dict[str, object]]: + content = message.get("content", []) + if isinstance(content, str): + return [] + if not isinstance(content, Iterable): + return [] + filtered = [part for part in content if isinstance(part, dict)] + return cast("list[dict[str, object]]", filtered) + + +def text_part(message: Mapping[str, object], index: int = 0) -> str: + part = content_parts(message)[index] + text = part.get("text") + assert isinstance(text, str) + return text + + +def image_url_part(message: Mapping[str, object], index: int = 0) -> dict[str, object]: + part = content_parts(message)[index] + image_url = part.get("image_url") + assert isinstance(image_url, dict) + return cast("dict[str, object]", image_url) + + +def file_part(message: Mapping[str, object], index: int = 0) -> dict[str, object]: + part = content_parts(message)[index] + file_obj = part.get("file") + assert isinstance(file_obj, dict) + return cast("dict[str, object]", file_obj) + class TestOpenAIUserConverter(unittest.TestCase): """Test cases for conversion from user role MCP message types to OpenAI API.""" @@ -59,10 +95,10 @@ def test_image_content_conversion(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "image_url") + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "image_url") self.assertEqual( - openai_msg["content"][0]["image_url"]["url"], + image_url_part(openai_msg)["url"], f"data:image/jpeg;base64,{self.sample_image_base64}", ) @@ -70,7 +106,7 @@ def test_embedded_resource_text_conversion(self): """Test conversion of text-based EmbeddedResource to OpenAI text content with fastagent:file tags.""" # Create a text resource text_resource = TextResourceContents( - uri="test://example.com/document.txt", + uri=AnyUrl("test://example.com/document.txt"), mimeType="text/plain", text=self.sample_text, ) @@ -84,20 +120,20 @@ def test_embedded_resource_text_conversion(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertIn("", openai_msg["content"][0]["text"]) + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "text") + self.assertIn("", text_part(openai_msg)) def test_embedded_resource_pdf_conversion(self): """Test conversion of PDF EmbeddedResource to OpenAI file part.""" # Create a PDF resource pdf_base64 = base64.b64encode(b"fake_pdf_data").decode("utf-8") pdf_resource = BlobResourceContents( - uri="test://example.com/document.pdf", + uri=AnyUrl("test://example.com/document.pdf"), mimeType="application/pdf", blob=pdf_base64, ) @@ -111,11 +147,11 @@ def test_embedded_resource_pdf_conversion(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "file") - self.assertEqual(openai_msg["content"][0]["file"]["filename"], "document.pdf") + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "file") + self.assertEqual(file_part(openai_msg)["filename"], "document.pdf") self.assertEqual( - openai_msg["content"][0]["file"]["file_data"], + file_part(openai_msg)["file_data"], f"data:application/pdf;base64,{pdf_base64}", ) @@ -123,7 +159,7 @@ def test_embedded_resource_image_url_conversion(self): """Test conversion of image URL in EmbeddedResource to OpenAI image block.""" # Create an image resource with URL image_resource = BlobResourceContents( - uri="https://example.com/image.jpg", + uri=AnyUrl("https://example.com/image.jpg"), mimeType="image/jpeg", blob=self.sample_image_base64, # This would be ignored for URL in OpenAI ) @@ -137,10 +173,10 @@ def test_embedded_resource_image_url_conversion(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "image_url") + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "image_url") self.assertEqual( - openai_msg["content"][0]["image_url"]["url"], + image_url_part(openai_msg)["url"], "https://example.com/image.jpg", ) @@ -163,12 +199,12 @@ def test_linked_resource_conversion(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertIn("Some description", openai_msg["content"][0]["text"]) - self.assertIn("some name", openai_msg["content"][0]["text"]) - self.assertIn("text/plain", openai_msg["content"][0]["text"]) - self.assertIn("test://example.com/document.txt", openai_msg["content"][0]["text"]) + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "text") + self.assertIn("Some description", text_part(openai_msg)) + self.assertIn("some name", text_part(openai_msg)) + self.assertIn("text/plain", text_part(openai_msg)) + self.assertIn("test://example.com/document.txt", text_part(openai_msg)) def test_multiple_content_blocks(self): """Test conversion of messages with multiple content blocks.""" @@ -190,19 +226,19 @@ def test_multiple_content_blocks(self): # Assertions self.assertEqual(openai_msg["role"], "user") - self.assertEqual(len(openai_msg["content"]), 3) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertEqual(openai_msg["content"][0]["text"], "First text") - self.assertEqual(openai_msg["content"][1]["type"], "image_url") - self.assertEqual(openai_msg["content"][2]["type"], "text") - self.assertEqual(openai_msg["content"][2]["text"], "Second text") + self.assertEqual(len(content_parts(openai_msg)), 3) + self.assertEqual(content_parts(openai_msg)[0]["type"], "text") + self.assertEqual(text_part(openai_msg), "First text") + self.assertEqual(content_parts(openai_msg)[1]["type"], "image_url") + self.assertEqual(content_parts(openai_msg)[2]["type"], "text") + self.assertEqual(text_part(openai_msg, 2), "Second text") def test_svg_resource_conversion(self): """Test handling of SVG resources - should convert to text with fastagent:file tags for OpenAI.""" # Create an embedded SVG resource svg_content = '' svg_resource = TextResourceContents( - uri="test://example.com/image.svg", + uri=AnyUrl("test://example.com/image.svg"), mimeType="image/svg+xml", text=svg_content, ) @@ -215,13 +251,13 @@ def test_svg_resource_conversion(self): openai_msg = openai_msgs[0] # Should be converted to a text block with the SVG in fastagent:file tags - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertIn("", openai_msg["content"][0]["text"]) + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "text") + self.assertIn("", text_part(openai_msg)) def test_empty_content_list(self): """Test conversion with empty content list.""" @@ -242,7 +278,7 @@ def test_code_file_conversion(self): # Create a code resource code_resource = TextResourceContents( - uri="test://example.com/example.py", + uri=AnyUrl("test://example.com/example.py"), mimeType="text/x-python", text=code_text, ) @@ -256,13 +292,13 @@ def test_code_file_conversion(self): openai_msg = openai_msgs[0] # Check that proper fastagent:file tags are used - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertIn("", openai_msg["content"][0]["text"]) + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "text") + self.assertIn("", text_part(openai_msg)) class TestOpenAIAssistantConverter(unittest.TestCase): @@ -326,10 +362,10 @@ def test_convert_prompt_message_to_openai_user_image(self): # Assertions self.assertEqual(openai_msg["role"], "user") self.assertIsInstance(openai_msg["content"], list) - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "image_url") + self.assertEqual(len(content_parts(openai_msg)), 1) + self.assertEqual(content_parts(openai_msg)[0]["type"], "image_url") self.assertEqual( - openai_msg["content"][0]["image_url"]["url"], + image_url_part(openai_msg)["url"], f"data:image/jpeg;base64,{image_base64}", ) @@ -337,7 +373,7 @@ def test_convert_prompt_message_embedded_resource_to_openai(self): """Test conversion of a PromptMessage with embedded resource to OpenAI format.""" # Create a PromptMessage with embedded text resource text_resource = TextResourceContents( - uri="test://example.com/document.txt", + uri=AnyUrl("test://example.com/document.txt"), mimeType="text/plain", text="This is a text resource", ) @@ -350,10 +386,10 @@ def test_convert_prompt_message_embedded_resource_to_openai(self): # Assertions self.assertEqual(openai_msg["role"], "user") self.assertIsInstance(openai_msg["content"], list) - self.assertEqual(len(openai_msg["content"]), 1) - self.assertEqual(openai_msg["content"][0]["type"], "text") - self.assertIn(" dict[str, Any]: + assert isinstance(value, dict) + return cast("dict[str, Any]", value) + + +def _as_content_list(value: object) -> list[dict[str, Any]]: + assert isinstance(value, list) + for item in value: + assert isinstance(item, dict) + return cast("list[dict[str, Any]]", value) + + def apply_cache_control_to_message(message: dict[str, Any], position: int) -> bool: """ Apply cache control to a message at the specified position. @@ -44,8 +56,9 @@ def test_apply_cache_control_to_valid_message(self): success = apply_cache_control_to_message(message, 0) self.assertTrue(success) - self.assertIn("cache_control", message["content"][0]) - self.assertEqual(message["content"][0]["cache_control"]["type"], "ephemeral") + content_list = _as_content_list(message["content"]) + self.assertIn("cache_control", content_list[0]) + self.assertEqual(content_list[0]["cache_control"]["type"], "ephemeral") def test_apply_cache_control_to_message_with_multiple_blocks(self): """Test applying cache control to message with multiple content blocks.""" @@ -61,9 +74,10 @@ def test_apply_cache_control_to_message_with_multiple_blocks(self): self.assertTrue(success) # Should apply to the last block - self.assertNotIn("cache_control", message["content"][0]) - self.assertIn("cache_control", message["content"][1]) - self.assertEqual(message["content"][1]["cache_control"]["type"], "ephemeral") + content_list = _as_content_list(message["content"]) + self.assertNotIn("cache_control", content_list[0]) + self.assertIn("cache_control", content_list[1]) + self.assertEqual(content_list[1]["cache_control"]["type"], "ephemeral") def test_apply_cache_control_to_invalid_message(self): """Test that cache control is not applied to invalid messages.""" @@ -91,10 +105,11 @@ def test_apply_cache_control_preserves_existing_content(self): self.assertTrue(success) # Original content should be preserved - self.assertEqual(message["content"][0]["text"], original_text) - self.assertEqual(message["content"][0]["type"], "text") + content_list = _as_content_list(message["content"]) + self.assertEqual(content_list[0]["text"], original_text) + self.assertEqual(content_list[0]["type"], "text") # Cache control should be added - self.assertIn("cache_control", message["content"][0]) + self.assertIn("cache_control", content_list[0]) def test_multipart_to_anthropic_conversion_preserves_structure(self): """Test that PromptMessageExtended -> Anthropic conversion preserves structure for caching.""" @@ -107,14 +122,16 @@ def test_multipart_to_anthropic_conversion_preserves_structure(self): # Verify structure is suitable for cache control application self.assertEqual(anthropic_msg["role"], "user") - self.assertIsInstance(anthropic_msg["content"], list) - self.assertTrue(len(anthropic_msg["content"]) > 0) - self.assertIsInstance(anthropic_msg["content"][0], dict) + content_list = list(anthropic_msg.get("content", [])) + self.assertTrue(len(content_list) > 0) + self.assertIsInstance(content_list[0], dict) # Apply cache control to the converted message - success = apply_cache_control_to_message(anthropic_msg, 0) + msg = _as_dict(anthropic_msg) + success = apply_cache_control_to_message(msg, 0) self.assertTrue(success) - self.assertIn("cache_control", anthropic_msg["content"][0]) + content_list = _as_content_list(msg["content"]) + self.assertIn("cache_control", content_list[0]) def test_multiple_multipart_messages_conversion_and_caching(self): """Test converting multiple multipart messages and applying cache to specific position.""" @@ -130,17 +147,21 @@ def test_multiple_multipart_messages_conversion_and_caching(self): # Apply cache control to position 2 (3rd message) cache_position = 2 - success = apply_cache_control_to_message(messages[cache_position], cache_position) + msg = messages[cache_position] + assert isinstance(msg, dict) + success = apply_cache_control_to_message(msg, cache_position) self.assertTrue(success) # Verify only the target message has cache control for i, msg in enumerate(messages): if i == cache_position: - self.assertIn("cache_control", msg["content"][0]) - self.assertEqual(msg["content"][0]["cache_control"]["type"], "ephemeral") + content_list = _as_content_list(msg["content"]) + self.assertIn("cache_control", content_list[0]) + self.assertEqual(content_list[0]["cache_control"]["type"], "ephemeral") else: - self.assertNotIn("cache_control", msg["content"][0]) + content_list = _as_content_list(msg["content"]) + self.assertNotIn("cache_control", content_list[0]) def test_assistant_message_caching(self): """Test that assistant messages can also receive cache control.""" @@ -148,11 +169,13 @@ def test_assistant_message_caching(self): multipart = PromptMessageExtended(role="assistant", content=[text_content]) anthropic_msg = AnthropicConverter.convert_to_anthropic(multipart) - success = apply_cache_control_to_message(anthropic_msg, 0) + msg = _as_dict(anthropic_msg) + success = apply_cache_control_to_message(msg, 0) self.assertTrue(success) self.assertEqual(anthropic_msg["role"], "assistant") - self.assertIn("cache_control", anthropic_msg["content"][0]) + content_list = _as_content_list(msg["content"]) + self.assertIn("cache_control", content_list[0]) def test_cache_control_application_idempotent(self): """Test that applying cache control multiple times doesn't break anything.""" @@ -166,8 +189,9 @@ def test_cache_control_application_idempotent(self): self.assertTrue(success2) # Should still have cache control - self.assertIn("cache_control", message["content"][0]) - self.assertEqual(message["content"][0]["cache_control"]["type"], "ephemeral") + content_list = _as_content_list(message["content"]) + self.assertIn("cache_control", content_list[0]) + self.assertEqual(content_list[0]["cache_control"]["type"], "ephemeral") if __name__ == "__main__": diff --git a/tests/unit/fast_agent/llm/test_max_tokens_acp_regression.py b/tests/unit/fast_agent/llm/test_max_tokens_acp_regression.py index 85d21ad9c..b53d5211c 100644 --- a/tests/unit/fast_agent/llm/test_max_tokens_acp_regression.py +++ b/tests/unit/fast_agent/llm/test_max_tokens_acp_regression.py @@ -6,17 +6,25 @@ The merge logic should preserve the model-aware maxTokens from default_request_params. """ +from typing import TypeGuard + import pytest from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent from fast_agent.config import HuggingFaceSettings, Settings from fast_agent.context import Context +from fast_agent.interfaces import FastAgentLLMProtocol +from fast_agent.llm.fastagent_llm import FastAgentLLM from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.provider.openai.llm_huggingface import HuggingFaceLLM from fast_agent.types import RequestParams +def _is_fastagent_llm(value: FastAgentLLMProtocol) -> TypeGuard[FastAgentLLM]: + return isinstance(value, FastAgentLLM) + + class TestModelDatabaseLookup: """Test that ModelDatabase correctly looks up the kimi model.""" @@ -202,6 +210,7 @@ async def test_attach_llm_via_factory_gets_correct_max_tokens(self): factory = ModelFactory.create_factory("hf.moonshotai/kimi-k2-instruct-0905") llm = await agent.attach_llm(factory) + assert _is_fastagent_llm(llm) assert llm.default_request_params.maxTokens == 16384, ( f"Expected 16384, got {llm.default_request_params.maxTokens}" @@ -218,6 +227,7 @@ async def test_kimi_alias_via_factory_gets_correct_max_tokens(self): factory = ModelFactory.create_factory("kimi") llm = await agent.attach_llm(factory) + assert _is_fastagent_llm(llm) # kimi alias resolves to hf.moonshotai/Kimi-K2-Instruct-0905:groq # which should get maxTokens=16384 from KIMI_MOONSHOT @@ -235,6 +245,7 @@ async def test_kimithink_alias_via_factory_gets_correct_max_tokens(self): factory = ModelFactory.create_factory("kimithink") llm = await agent.attach_llm(factory) + assert _is_fastagent_llm(llm) # kimithink alias resolves to hf.moonshotai/Kimi-K2-Thinking:together # which should get maxTokens=16384 from KIMI_MOONSHOT_THINKING @@ -268,6 +279,7 @@ async def test_apply_model_flow_preserves_model_max_tokens(self): # Create the factory and attach LLM with the params factory = ModelFactory.create_factory("kimi") llm = await agent.attach_llm(factory, request_params=recreated_params) + assert _is_fastagent_llm(llm) # maxTokens should be 16384 from the new model's ModelDatabase entry assert llm.default_request_params.maxTokens == 16384, ( @@ -290,6 +302,7 @@ async def test_attach_llm_then_acp_merge_preserves_max_tokens(self): factory = ModelFactory.create_factory("hf.moonshotai/kimi-k2-instruct-0905") llm = await agent.attach_llm(factory) + assert _is_fastagent_llm(llm) # Verify LLM was created with correct maxTokens assert llm.default_request_params.maxTokens == 16384, ( @@ -322,6 +335,7 @@ async def test_colon_syntax_bug_model_name_corrupted(self): factory = ModelFactory.create_factory("hf:moonshotai/kimi-k2-instruct-0905") llm = await agent.attach_llm(factory) + assert _is_fastagent_llm(llm) # This documents the current buggy behavior # The model name gets corrupted with a leading colon diff --git a/tests/unit/fast_agent/llm/test_model_database.py b/tests/unit/fast_agent/llm/test_model_database.py index d940c89b7..d9af03c9a 100644 --- a/tests/unit/fast_agent/llm/test_model_database.py +++ b/tests/unit/fast_agent/llm/test_model_database.py @@ -1,11 +1,14 @@ + from fast_agent.agents.agent_types import AgentConfig from fast_agent.agents.llm_agent import LlmAgent from fast_agent.config import HuggingFaceSettings, Settings from fast_agent.constants import DEFAULT_MAX_ITERATIONS from fast_agent.context import Context +from fast_agent.llm.fastagent_llm import FastAgentLLM from fast_agent.llm.model_database import ModelDatabase from fast_agent.llm.model_factory import ModelFactory from fast_agent.llm.provider.openai.llm_huggingface import HuggingFaceLLM +from fast_agent.llm.provider.openai.llm_openai import OpenAILLM def test_model_database_context_windows(): @@ -28,13 +31,14 @@ def test_model_database_max_tokens(): # Test fallbacks assert ModelDatabase.get_default_max_tokens("unknown-model") == 2048 - assert ModelDatabase.get_default_max_tokens(None) == 2048 + assert ModelDatabase.get_default_max_tokens("") == 2048 def test_model_database_tokenizes(): """Test that ModelDatabase returns expected tokenization types""" # Test multimodal model claude_tokenizes = ModelDatabase.get_tokenizes("claude-sonnet-4-0") + assert claude_tokenizes is not None assert "text/plain" in claude_tokenizes assert "image/jpeg" in claude_tokenizes assert "application/pdf" in claude_tokenizes @@ -93,16 +97,19 @@ def test_llm_uses_model_database_for_max_tokens(): # Test with a model that has 8192 max_output_tokens (should get full amount) factory = ModelFactory.create_factory("claude-sonnet-4-0") llm = factory(agent=agent) + assert isinstance(llm, FastAgentLLM) assert llm.default_request_params.maxTokens == 64000 # Test with a model that has high max_output_tokens (should get full amount) factory2 = ModelFactory.create_factory("o1") llm2 = factory2(agent=agent) + assert isinstance(llm2, FastAgentLLM) assert llm2.default_request_params.maxTokens == 100000 # Test with passthrough model (should get its configured max tokens) factory3 = ModelFactory.create_factory("passthrough") llm3 = factory3(agent=agent) + assert isinstance(llm3, FastAgentLLM) expected_max_tokens = ModelDatabase.get_default_max_tokens("passthrough") assert llm3.default_request_params.maxTokens == expected_max_tokens @@ -112,16 +119,19 @@ def test_llm_usage_tracking_uses_model_database(): factory = ModelFactory.create_factory("passthrough") agent = LlmAgent(AgentConfig(name="Test Agent")) llm = factory(agent=agent, model="claude-sonnet-4-0") + assert isinstance(llm, FastAgentLLM) # The usage_accumulator should be able to get context window from ModelDatabase # when it has a model set (this happens when turns are added) - llm.usage_accumulator.model = "claude-sonnet-4-0" - assert llm.usage_accumulator.context_window_size == 200000 + usage_accumulator = llm.usage_accumulator + assert usage_accumulator is not None + usage_accumulator.model = "claude-sonnet-4-0" + assert usage_accumulator.context_window_size == 200000 assert llm.default_request_params.maxTokens == 64000 # Should match ModelDatabase default # Test with unknown model - llm.usage_accumulator.model = "unknown-model" - assert llm.usage_accumulator.context_window_size is None + usage_accumulator.model = "unknown-model" + assert usage_accumulator.context_window_size is None def test_openai_provider_preserves_all_settings(): @@ -130,6 +140,7 @@ def test_openai_provider_preserves_all_settings(): agent = LlmAgent(AgentConfig(name="Test Agent")) llm = factory(agent=agent, instruction="You are a helpful assistant") + assert isinstance(llm, FastAgentLLM) # Verify all the original OpenAI settings are preserved params = llm.default_request_params @@ -165,6 +176,7 @@ def test_openai_llm_normalizes_repeated_roles(): agent = LlmAgent(AgentConfig(name="Test Agent")) factory = ModelFactory.create_factory("gpt-4o") llm = factory(agent=agent) + assert isinstance(llm, OpenAILLM) assert llm._normalize_role("assistantassistant") == "assistant" assert llm._normalize_role("assistantASSISTANTassistant") == "assistant" @@ -177,10 +189,12 @@ def test_openai_llm_uses_model_database_reasoning_flag(): agent = LlmAgent(AgentConfig(name="Test Agent")) reasoning_llm = ModelFactory.create_factory("o1")(agent=agent) + assert isinstance(reasoning_llm, OpenAILLM) assert reasoning_llm._reasoning assert getattr(reasoning_llm, "_reasoning_mode", None) == "openai" standard_llm = ModelFactory.create_factory("gpt-4o")(agent=agent) + assert isinstance(standard_llm, OpenAILLM) assert not standard_llm._reasoning assert getattr(standard_llm, "_reasoning_mode", None) is None diff --git a/tests/unit/fast_agent/llm/test_model_info_caps.py b/tests/unit/fast_agent/llm/test_model_info_caps.py index f7f9b8bc9..6198ca986 100644 --- a/tests/unit/fast_agent/llm/test_model_info_caps.py +++ b/tests/unit/fast_agent/llm/test_model_info_caps.py @@ -1,6 +1,7 @@ import pathlib import sys import types +from typing import TYPE_CHECKING, cast sys.path.append(str(pathlib.Path(__file__).resolve().parents[4] / "src")) @@ -11,14 +12,17 @@ class AgentCard: # minimal stub for imports pass - types_module.AgentCard = AgentCard - a2a_module.types = types_module + setattr(types_module, "AgentCard", AgentCard) + setattr(a2a_module, "types", types_module) sys.modules["a2a"] = a2a_module sys.modules["a2a.types"] = types_module from fast_agent.llm.model_info import ModelInfo from fast_agent.llm.provider_types import Provider +if TYPE_CHECKING: + from fast_agent.interfaces import FastAgentLLMProtocol + class DummyLLM: def __init__(self, model: str, provider: Provider = Provider.GOOGLE) -> None: @@ -44,7 +48,7 @@ def test_model_alias_capabilities_match_canonical() -> None: def test_model_info_from_llm_uses_canonical_name() -> None: - info = ModelInfo.from_llm(DummyLLM("gemini25")) + info = ModelInfo.from_llm(cast("FastAgentLLMProtocol", DummyLLM("gemini25"))) assert info is not None assert info.name == "gemini-2.5-flash-preview-09-2025" assert info.tdv_flags == (True, True, True) @@ -52,7 +56,7 @@ def test_model_info_from_llm_uses_canonical_name() -> None: def test_model_info_from_agent_llm_capabilities() -> None: agent = DummyAgent("gemini-2.5-pro", provider=Provider.GOOGLE) - info = ModelInfo.from_llm(agent.llm) + info = ModelInfo.from_llm(cast("FastAgentLLMProtocol", agent.llm)) assert info is not None assert info.name == "gemini-2.5-pro" assert info.tdv_flags == (True, True, True) diff --git a/tests/unit/fast_agent/llm/test_prepare_arguments.py b/tests/unit/fast_agent/llm/test_prepare_arguments.py index 8945abe1c..c6a050607 100644 --- a/tests/unit/fast_agent/llm/test_prepare_arguments.py +++ b/tests/unit/fast_agent/llm/test_prepare_arguments.py @@ -11,8 +11,9 @@ class StubLLM(FastAgentLLM): """Minimal implementation of FastAgentLLM for testing purposes""" - def __init__(self, *args, **kwargs): - super().__init__(provider=Provider.FAST_AGENT, *args, **kwargs) + def __init__(self, **kwargs: Any) -> None: + kwargs.pop("provider", None) + super().__init__(provider=Provider.FAST_AGENT, **kwargs) async def _apply_prompt_provider_specific( self, @@ -198,7 +199,7 @@ def test_none_values_not_included(self): llm = StubLLM() base_args = {"model": "test-model"} - params = RequestParams(temperature=None, top_p=0.9) + params = RequestParams(temperature=None, metadata={"top_p": 0.9}) result = llm.prepare_provider_arguments(base_args, params) diff --git a/tests/unit/fast_agent/llm/test_provider_key_manager_hf.py b/tests/unit/fast_agent/llm/test_provider_key_manager_hf.py index beb7c9b31..ffbaed45d 100644 --- a/tests/unit/fast_agent/llm/test_provider_key_manager_hf.py +++ b/tests/unit/fast_agent/llm/test_provider_key_manager_hf.py @@ -50,7 +50,7 @@ def test_get_api_key_from_config(): """Test getting HuggingFace API key from config.""" original = _set_hf_token(None) try: - config = Settings(huggingface=HuggingFaceSettings(api_key="hf_config_token")) + config = Settings(hf=HuggingFaceSettings(api_key="hf_config_token")) api_key = ProviderKeyManager.get_api_key("hf", config) assert api_key == "hf_config_token" finally: @@ -61,7 +61,7 @@ def test_config_takes_precedence_over_env(): """Test that config API key takes precedence over environment variable.""" original = _set_hf_token("hf_env_token") try: - config = Settings(huggingface=HuggingFaceSettings(api_key="hf_config_priority")) + config = Settings(hf=HuggingFaceSettings(api_key="hf_config_priority")) api_key = ProviderKeyManager.get_api_key("hf", config) assert api_key == "hf_config_priority" finally: diff --git a/tests/unit/fast_agent/llm/test_sampling_converter.py b/tests/unit/fast_agent/llm/test_sampling_converter.py index 92fe0c1b5..ab5e53992 100644 --- a/tests/unit/fast_agent/llm/test_sampling_converter.py +++ b/tests/unit/fast_agent/llm/test_sampling_converter.py @@ -14,6 +14,16 @@ from fast_agent.types import PromptMessageExtended +def _text(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + +def _image(block: object) -> ImageContent: + assert isinstance(block, ImageContent) + return block + + class TestSamplingConverter: """Tests for SamplingConverter""" @@ -29,8 +39,8 @@ def test_sampling_message_to_prompt_message_text(self): # Verify conversion assert prompt_message.role == "user" assert len(prompt_message.content) == 1 - assert prompt_message.content[0].type == "text" - assert prompt_message.content[0].text == "Hello, world!" + assert _text(prompt_message.content[0]).type == "text" + assert _text(prompt_message.content[0]).text == "Hello, world!" def test_sampling_message_to_prompt_message_image(self): """Test converting an image SamplingMessage to PromptMessageExtended""" @@ -46,9 +56,10 @@ def test_sampling_message_to_prompt_message_image(self): # Verify conversion assert prompt_message.role == "user" assert len(prompt_message.content) == 1 - assert prompt_message.content[0].type == "image" - assert prompt_message.content[0].data == "base64_encoded_image_data" - assert prompt_message.content[0].mimeType == "image/png" + image_block = _image(prompt_message.content[0]) + assert image_block.type == "image" + assert image_block.data == "base64_encoded_image_data" + assert image_block.mimeType == "image/png" def test_convert_messages(self): """Test converting multiple SamplingMessages to PromptMessageExtended objects""" @@ -67,13 +78,13 @@ def test_convert_messages(self): # Verify each message was converted correctly assert prompt_messages[0].role == "user" - assert prompt_messages[0].content[0].text == "Hello" + assert _text(prompt_messages[0].content[0]).text == "Hello" assert prompt_messages[1].role == "assistant" - assert prompt_messages[1].content[0].text == "Hi there" + assert _text(prompt_messages[1].content[0]).text == "Hi there" assert prompt_messages[2].role == "user" - assert prompt_messages[2].content[0].text == "How are you?" + assert _text(prompt_messages[2].content[0]).text == "How are you?" def test_convert_messages_with_mixed_content_types(self): """Test converting messages with different content types""" @@ -99,14 +110,15 @@ def test_convert_messages_with_mixed_content_types(self): # First message (text) assert prompt_messages[0].role == "user" - assert prompt_messages[0].content[0].type == "text" - assert prompt_messages[0].content[0].text == "What's in this image?" + assert _text(prompt_messages[0].content[0]).type == "text" + assert _text(prompt_messages[0].content[0]).text == "What's in this image?" # Second message (image) assert prompt_messages[1].role == "user" - assert prompt_messages[1].content[0].type == "image" - assert prompt_messages[1].content[0].data == "base64_encoded_image_data" - assert prompt_messages[1].content[0].mimeType == "image/png" + image_block = _image(prompt_messages[1].content[0]) + assert image_block.type == "image" + assert image_block.data == "base64_encoded_image_data" + assert image_block.mimeType == "image/png" def test_extract_request_params_full(self): """Test extracting RequestParams from CreateMessageRequestParams with all fields""" @@ -160,8 +172,8 @@ def test_error_result(self): # Verify result assert isinstance(result, CreateMessageResult) assert result.role == "assistant" - assert result.content.type == "text" - assert result.content.text == "Error in sampling: Test error" + assert _text(result.content).type == "text" + assert _text(result.content).text == "Error in sampling: Test error" assert result.model == model assert result.stopReason == "error" @@ -191,7 +203,9 @@ def test_sampling_message_with_tool_result(self): assert prompt_message.role == "user" assert prompt_message.tool_results is not None assert "call_123" in prompt_message.tool_results - assert prompt_message.tool_results["call_123"].content[0].text == "Tool result: 42" + tool_content = prompt_message.tool_results["call_123"].content + assert tool_content is not None + assert _text(tool_content[0]).text == "Tool result: 42" def test_sampling_message_with_tool_use(self): """Test converting a SamplingMessage with ToolUseContent (assistant response)""" diff --git a/tests/unit/fast_agent/llm/test_usage_tracking.py b/tests/unit/fast_agent/llm/test_usage_tracking.py index 7f03233b6..18ac5d37c 100644 --- a/tests/unit/fast_agent/llm/test_usage_tracking.py +++ b/tests/unit/fast_agent/llm/test_usage_tracking.py @@ -255,6 +255,7 @@ def test_cache_hit_rate_calculation(): # Total cache: 300 (anthropic read) + 200 (openai hit) = 500 # Hit rate: 500 / (2100 + 500) * 100 = 500/2600 = 19.23% expected_hit_rate = 500 / (2100 + 500) * 100 + assert accumulator.cache_hit_rate is not None assert abs(accumulator.cache_hit_rate - expected_hit_rate) < 0.01 # Test with no input tokens diff --git a/tests/unit/fast_agent/mcp/prompts/test_template_multipart_integration.py b/tests/unit/fast_agent/mcp/prompts/test_template_multipart_integration.py index 43781efe3..d2cb790b8 100644 --- a/tests/unit/fast_agent/mcp/prompts/test_template_multipart_integration.py +++ b/tests/unit/fast_agent/mcp/prompts/test_template_multipart_integration.py @@ -19,6 +19,11 @@ ) +def _text(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + class TestTemplateMultipartIntegration: """Tests for integration between PromptTemplate and PromptMessageExtended.""" @@ -45,13 +50,17 @@ def test_template_to_extended_conversion(self): assert len(multiparts) == 2 assert multiparts[0].role == "user" assert len(multiparts[0].content) == 1 - assert multiparts[0].content[0].type == "text" - assert "Hello, I'm trying to learn about {{topic}}." in multiparts[0].content[0].text + assert _text(multiparts[0].content[0]).type == "text" + assert "Hello, I'm trying to learn about {{topic}}." in _text( + multiparts[0].content[0] + ).text assert multiparts[1].role == "assistant" assert len(multiparts[1].content) == 1 - assert multiparts[1].content[0].type == "text" - assert "I'd be happy to help you learn about {{topic}}!" in multiparts[1].content[0].text + assert _text(multiparts[1].content[0]).type == "text" + assert "I'd be happy to help you learn about {{topic}}!" in _text( + multiparts[1].content[0] + ).text def test_template_with_substitutions_to_extended(self): """Test applying substitutions to a template and converting to extended.""" @@ -72,13 +81,14 @@ def test_template_with_substitutions_to_extended(self): assert len(multiparts) == 2 assert multiparts[0].role == "user" assert ( - "Hello, I'm trying to learn about Python programming." in multiparts[0].content[0].text + "Hello, I'm trying to learn about Python programming." + in _text(multiparts[0].content[0]).text ) assert multiparts[1].role == "assistant" assert ( "I'd be happy to help you learn about Python programming!" - in multiparts[1].content[0].text + in _text(multiparts[1].content[0]).text ) def test_multipart_to_template_conversion(self): @@ -183,15 +193,21 @@ def test_save_and_load_from_file(self, temp_delimited_file): # Check user message assert loaded_messages[0].role == "user" assert len(loaded_messages[0].content) == 1 - assert loaded_messages[0].content[0].type == "text" - assert "Can you explain quantum physics?" in loaded_messages[0].content[0].text + assert _text(loaded_messages[0].content[0]).type == "text" + assert "Can you explain quantum physics?" in _text( + loaded_messages[0].content[0] + ).text # Check assistant message assert loaded_messages[1].role == "assistant" assert len(loaded_messages[1].content) == 1 - assert loaded_messages[1].content[0].type == "text" - assert "Quantum physics is fascinating" in loaded_messages[1].content[0].text - assert "behavior of matter" in loaded_messages[1].content[0].text.lower() + assert _text(loaded_messages[1].content[0]).type == "text" + assert "Quantum physics is fascinating" in _text( + loaded_messages[1].content[0] + ).text + assert "behavior of matter" in _text( + loaded_messages[1].content[0] + ).text.lower() def test_template_loader_integration(self, temp_delimited_file): """Test integration with PromptTemplateLoader.""" diff --git a/tests/unit/fast_agent/mcp/test_agent_server_tool_description.py b/tests/unit/fast_agent/mcp/test_agent_server_tool_description.py index 391f7e85d..c3d582952 100644 --- a/tests/unit/fast_agent/mcp/test_agent_server_tool_description.py +++ b/tests/unit/fast_agent/mcp/test_agent_server_tool_description.py @@ -1,10 +1,14 @@ import asyncio from contextlib import AsyncExitStack +from typing import TYPE_CHECKING, cast from fast_agent.core.agent_app import AgentApp from fast_agent.core.fastagent import AgentInstance from fast_agent.mcp.server.agent_server import AgentMCPServer +if TYPE_CHECKING: + from fast_agent.interfaces import AgentProtocol + class _DummyLLM: def __init__(self) -> None: @@ -24,7 +28,7 @@ async def shutdown(self) -> None: def test_tool_description_supports_agent_placeholder(): async def create_instance() -> AgentInstance: - agent = _DummyAgent() + agent = cast("AgentProtocol", _DummyAgent()) app = AgentApp({"worker": agent}) return AgentInstance(app=app, agents={"worker": agent}) @@ -47,7 +51,7 @@ async def dispose_instance(instance: AgentInstance) -> None: def test_tool_description_defaults_when_not_provided(): async def create_instance() -> AgentInstance: - agent = _DummyAgent() + agent = cast("AgentProtocol", _DummyAgent()) app = AgentApp({"writer": agent}) return AgentInstance(app=app, agents={"writer": agent}) @@ -79,7 +83,7 @@ async def _exercise_request_scope(): async def create_instance() -> AgentInstance: nonlocal create_count create_count += 1 - agent = _DummyAgent() + agent = cast("AgentProtocol", _DummyAgent()) app = AgentApp({"worker": agent}) return AgentInstance(app=app, agents={"worker": agent}) @@ -127,7 +131,7 @@ async def _exercise_connection_scope(): async def create_instance() -> AgentInstance: nonlocal create_count create_count += 1 - agent = _DummyAgent() + agent = cast("AgentProtocol", _DummyAgent()) app = AgentApp({"worker": agent}) return AgentInstance(app=app, agents={"worker": agent}) diff --git a/tests/unit/fast_agent/mcp/test_cimd.py b/tests/unit/fast_agent/mcp/test_cimd.py index 0855364da..2b043716d 100644 --- a/tests/unit/fast_agent/mcp/test_cimd.py +++ b/tests/unit/fast_agent/mcp/test_cimd.py @@ -238,6 +238,7 @@ def test_callback_server_uses_loopback_ip(self): assert server.actual_port is not None assert server.actual_port > 0 # The server is bound to 127.0.0.1 + assert server._server is not None assert server._server.server_address[0] == "127.0.0.1" finally: server.stop() diff --git a/tests/unit/fast_agent/mcp/test_hf_auth.py b/tests/unit/fast_agent/mcp/test_hf_auth.py index 6a4de05ba..1d5e82303 100644 --- a/tests/unit/fast_agent/mcp/test_hf_auth.py +++ b/tests/unit/fast_agent/mcp/test_hf_auth.py @@ -452,6 +452,7 @@ def test_respects_existing_authorization_completely(self): # Should return exact same headers, no modification assert result == existing_headers + assert result is not None assert result["Authorization"] == "Bearer user_provided_token" finally: _restore_hf_token(original) diff --git a/tests/unit/fast_agent/mcp/test_mcp_aggregator_skybridge.py b/tests/unit/fast_agent/mcp/test_mcp_aggregator_skybridge.py index 582a26690..22304759c 100644 --- a/tests/unit/fast_agent/mcp/test_mcp_aggregator_skybridge.py +++ b/tests/unit/fast_agent/mcp/test_mcp_aggregator_skybridge.py @@ -5,6 +5,7 @@ import types from pathlib import Path from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock from mcp.types import Tool @@ -43,6 +44,16 @@ class StrEnum(str, enum.Enum): NamespacedTool = _module.NamespacedTool +def _tool_with_meta(name: str, input_schema: dict[str, Any], meta: dict[str, Any]) -> Tool: + return Tool.model_validate( + { + "name": name, + "inputSchema": input_schema, + "_meta": meta, + } + ) + + def _create_aggregator() -> MCPAggregator: """Create an aggregator instance suitable for unit testing.""" aggregator = MCPAggregator( @@ -62,10 +73,10 @@ def test_skybridge_detection_marks_valid_resources() -> None: aggregator.server_supports_feature = AsyncMock(return_value=True) # type: ignore[attr-defined] aggregator._server_to_tool_map["test"] = [ NamespacedTool( - tool=Tool( + tool=_tool_with_meta( name="tool_a", - inputSchema={"type": "object"}, - _meta={"openai/outputTemplate": "ui://component/app"}, + input_schema={"type": "object"}, + meta={"openai/outputTemplate": "ui://component/app"}, ), server_name="test", namespaced_tool_name="test.tool_a", @@ -105,10 +116,10 @@ def test_skybridge_detection_warns_on_invalid_mime() -> None: aggregator.server_supports_feature = AsyncMock(return_value=True) # type: ignore[attr-defined] aggregator._server_to_tool_map["test"] = [ NamespacedTool( - tool=Tool( + tool=_tool_with_meta( name="tool_a", - inputSchema={"type": "object"}, - _meta={"openai/outputTemplate": "ui://component/app"}, + input_schema={"type": "object"}, + meta={"openai/outputTemplate": "ui://component/app"}, ), server_name="test", namespaced_tool_name="test.tool_a", @@ -154,10 +165,10 @@ def test_skybridge_detection_handles_missing_resources_capability() -> None: aggregator.server_supports_feature = AsyncMock(return_value=False) # type: ignore[attr-defined] aggregator._server_to_tool_map["test"] = [ NamespacedTool( - tool=Tool( + tool=_tool_with_meta( name="tool_a", - inputSchema={"type": "object"}, - _meta={"openai/outputTemplate": "ui://component/app"}, + input_schema={"type": "object"}, + meta={"openai/outputTemplate": "ui://component/app"}, ), server_name="test", namespaced_tool_name="test.tool_a", @@ -179,10 +190,10 @@ def test_list_tools_marks_skybridge_meta() -> None: aggregator = _create_aggregator() aggregator.initialized = True - tool = Tool( + tool = _tool_with_meta( name="tool_a", - inputSchema={"type": "object"}, - _meta={"openai/outputTemplate": "ui://component/app"}, + input_schema={"type": "object"}, + meta={"openai/outputTemplate": "ui://component/app"}, ) namespaced = NamespacedTool( diff --git a/tests/unit/fast_agent/mcp/test_prompt_format_utils.py b/tests/unit/fast_agent/mcp/test_prompt_format_utils.py index 2933b9a2c..efe03247d 100644 --- a/tests/unit/fast_agent/mcp/test_prompt_format_utils.py +++ b/tests/unit/fast_agent/mcp/test_prompt_format_utils.py @@ -13,6 +13,7 @@ TextContent, TextResourceContents, ) +from pydantic import AnyUrl from fast_agent.mcp.prompt_message_extended import PromptMessageExtended from fast_agent.mcp.prompt_serialization import ( @@ -23,6 +24,21 @@ ) +def _text(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + +def _resource(block: object) -> EmbeddedResource: + assert isinstance(block, EmbeddedResource) + return block + + +def _resource_text(resource: object) -> str: + assert isinstance(resource, TextResourceContents) + return resource.text + + class TestPromptFormatUtils: """Tests for the prompt_format_utils module.""" @@ -37,7 +53,7 @@ def test_multipart_with_resources_to_delimited(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://code.py", + uri=AnyUrl("resource://code.py"), mimeType="text/x-python", text='print("Hello, World!")', ), @@ -54,7 +70,7 @@ def test_multipart_with_resources_to_delimited(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://improved_code.py", + uri=AnyUrl("resource://improved_code.py"), mimeType="text/x-python", text='def main():\n print("Hello, World!")\n\nif __name__ == "__main__":\n main()', ), @@ -136,21 +152,23 @@ def test_delimited_with_resources_to_extended(self): assert len(messages) == 2 assert messages[0].role == "user" assert len(messages[0].content) == 2 # Text and resource - assert messages[0].content[0].type == "text" - assert "Here's a CSS file" in messages[0].content[0].text - assert messages[0].content[1].type == "resource" - assert str(messages[0].content[1].resource.uri) == "resource://styles.css" - assert messages[0].content[1].resource.mimeType == "text/css" - assert messages[0].content[1].resource.text == "body { color: black; }" + assert _text(messages[0].content[0]).type == "text" + assert "Here's a CSS file" in _text(messages[0].content[0]).text + resource = _resource(messages[0].content[1]) + assert resource.type == "resource" + assert str(resource.resource.uri) == "resource://styles.css" + assert resource.resource.mimeType == "text/css" + assert _resource_text(resource.resource) == "body { color: black; }" assert messages[1].role == "assistant" assert len(messages[1].content) == 2 # Text and resource - assert messages[1].content[0].type == "text" - assert "I've reviewed your CSS" in messages[1].content[0].text - assert messages[1].content[1].type == "resource" - assert str(messages[1].content[1].resource.uri) == "resource://improved_styles.css" - assert messages[1].content[1].resource.mimeType == "text/css" - assert messages[1].content[1].resource.text == "body { color: #000; }" + assert _text(messages[1].content[0]).type == "text" + assert "I've reviewed your CSS" in _text(messages[1].content[0]).text + resource = _resource(messages[1].content[1]) + assert resource.type == "resource" + assert str(resource.resource.uri) == "resource://improved_styles.css" + assert resource.resource.mimeType == "text/css" + assert _resource_text(resource.resource) == "body { color: #000; }" def test_multiple_resources_in_one_message(self): """Test handling multiple resources in a single message.""" @@ -162,7 +180,7 @@ def test_multiple_resources_in_one_message(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://data1.csv", + uri=AnyUrl("resource://data1.csv"), mimeType="text/csv", text="id,name,value\n1,A,10\n2,B,20", ), @@ -170,7 +188,7 @@ def test_multiple_resources_in_one_message(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://data2.csv", + uri=AnyUrl("resource://data2.csv"), mimeType="text/csv", text="id,name,value\n3,C,30\n4,D,40", ), @@ -212,18 +230,20 @@ def test_multiple_resources_in_one_message(self): assert len(messages) == 1 assert messages[0].role == "user" assert len(messages[0].content) == 3 # Text and two resources - assert messages[0].content[0].type == "text" - assert messages[0].content[1].type == "resource" + assert _text(messages[0].content[0]).type == "text" + assert _resource(messages[0].content[1]).type == "resource" assert messages[0].content[2].type == "resource" # Verify resource content is preserved - assert str(messages[0].content[1].resource.uri) == "resource://data1.csv" - assert messages[0].content[1].resource.mimeType == "text/csv" - assert "id,name,value" in messages[0].content[1].resource.text + resource = _resource(messages[0].content[1]) + assert str(resource.resource.uri) == "resource://data1.csv" + assert resource.resource.mimeType == "text/csv" + assert "id,name,value" in _resource_text(resource.resource) - assert str(messages[0].content[2].resource.uri) == "resource://data2.csv" - assert messages[0].content[2].resource.mimeType == "text/csv" - assert "id,name,value" in messages[0].content[2].resource.text + resource = _resource(messages[0].content[2]) + assert str(resource.resource.uri) == "resource://data2.csv" + assert resource.resource.mimeType == "text/csv" + assert "id,name,value" in _resource_text(resource.resource) def test_image_handling(self): """Test handling image content in multipart messages.""" @@ -290,7 +310,7 @@ def test_save_and_load_with_resources(self, temp_resource_file): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://config.json", + uri=AnyUrl("resource://config.json"), mimeType="application/json", text='{"key": "value"}', ), @@ -309,9 +329,10 @@ def test_save_and_load_with_resources(self, temp_resource_file): assert len(loaded_messages) == 1 assert loaded_messages[0].role == "user" assert len(loaded_messages[0].content) == 2 # Text and resource - assert loaded_messages[0].content[0].type == "text" - assert loaded_messages[0].content[1].type == "resource" - assert str(loaded_messages[0].content[1].resource.uri) == "resource://config.json" + assert _text(loaded_messages[0].content[0]).type == "text" + resource = _resource(loaded_messages[0].content[1]) + assert resource.type == "resource" + assert str(resource.resource.uri) == "resource://config.json" def test_round_trip_with_mime_types(self): """Test round-trip conversion preserving MIME type information.""" @@ -324,7 +345,7 @@ def test_round_trip_with_mime_types(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://script.js", + uri=AnyUrl("resource://script.js"), mimeType="application/javascript", text="function hello() { return 'Hello!'; }", ), @@ -332,7 +353,7 @@ def test_round_trip_with_mime_types(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://style.css", + uri=AnyUrl("resource://style.css"), mimeType="text/css", text="body { color: blue; }", ), @@ -360,6 +381,6 @@ def test_round_trip_with_mime_types(self): assert len(resources) == 2 # Resource URIs should be preserved - resource_uris = [str(resource.resource.uri) for resource in resources] + resource_uris = [str(_resource(resource).resource.uri) for resource in resources] assert "resource://script.js" in resource_uris assert "resource://style.css" in resource_uris diff --git a/tests/unit/fast_agent/mcp/test_prompt_message_multipart.py b/tests/unit/fast_agent/mcp/test_prompt_message_multipart.py index cf498f138..95c67cfd8 100644 --- a/tests/unit/fast_agent/mcp/test_prompt_message_multipart.py +++ b/tests/unit/fast_agent/mcp/test_prompt_message_multipart.py @@ -13,6 +13,16 @@ from fast_agent.mcp.prompt_message_extended import PromptMessageExtended +def _text(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + +def _image(block: object) -> ImageContent: + assert isinstance(block, ImageContent) + return block + + class TestPromptMessageExtended: """Tests for the PromptMessageExtended class.""" @@ -31,8 +41,8 @@ def test_from_prompt_messages_with_single_role(self): assert len(result) == 1 assert result[0].role == "user" assert len(result[0].content) == 2 - assert result[0].content[0].text == "Hello" - assert result[0].content[1].text == "How are you?" + assert _text(result[0].content[0]).text == "Hello" + assert _text(result[0].content[1]).text == "How are you?" def test_from_prompt_messages_with_multiple_roles(self): """Test converting a sequence of PromptMessages with different roles.""" @@ -54,9 +64,9 @@ def test_from_prompt_messages_with_multiple_roles(self): assert len(result[0].content) == 1 assert len(result[1].content) == 1 assert len(result[2].content) == 1 - assert result[0].content[0].text == "Hello" - assert result[1].content[0].text == "Hi there!" - assert result[2].content[0].text == "How are you?" + assert _text(result[0].content[0]).text == "Hello" + assert _text(result[1].content[0]).text == "Hi there!" + assert _text(result[2].content[0]).text == "How are you?" def test_from_prompt_messages_with_mixed_content_types(self): """Test converting messages with mixed content types (text and image).""" @@ -80,10 +90,11 @@ def test_from_prompt_messages_with_mixed_content_types(self): assert len(result) == 1 assert result[0].role == "user" assert len(result[0].content) == 2 - assert result[0].content[0].text == "Look at this image:" - assert result[0].content[1].type == "image" - assert result[0].content[1].data == "base64_encoded_image_data" - assert result[0].content[1].mimeType == "image/png" + assert _text(result[0].content[0]).text == "Look at this image:" + image_block = _image(result[0].content[1]) + assert image_block.type == "image" + assert image_block.data == "base64_encoded_image_data" + assert image_block.mimeType == "image/png" def test_to_prompt_messages(self): """Test converting a PromptMessageExtended back to PromptMessages.""" @@ -103,8 +114,8 @@ def test_to_prompt_messages(self): assert len(result) == 2 assert result[0].role == "user" assert result[1].role == "user" - assert result[0].content.text == "Hello" - assert result[1].content.text == "How are you?" + assert _text(result[0].content).text == "Hello" + assert _text(result[1].content).text == "How are you?" def test_parse_get_prompt_result(self): """Test parsing a GetPromptResult into PromptMessageExtended objects.""" @@ -129,9 +140,9 @@ def test_parse_get_prompt_result(self): assert len(multiparts[0].content) == 1 assert len(multiparts[1].content) == 1 assert len(multiparts[2].content) == 1 - assert multiparts[0].content[0].text == "Hello" - assert multiparts[1].content[0].text == "Hi there!" - assert multiparts[2].content[0].text == "How are you?" + assert _text(multiparts[0].content[0]).text == "Hello" + assert _text(multiparts[1].content[0]).text == "Hi there!" + assert _text(multiparts[2].content[0]).text == "How are you?" def test_empty_messages(self): """Test handling of empty message lists.""" @@ -165,7 +176,7 @@ def test_round_trip_conversion(self): assert len(result) == len(messages) for i in range(len(messages)): assert result[i].role == messages[i].role - assert result[i].content.text == messages[i].content.text + assert _text(result[i].content).text == _text(messages[i].content).text def test_from_get_prompt_result(self): """Test from_get_prompt_result method with error handling.""" diff --git a/tests/unit/fast_agent/mcp/test_prompt_multipart_conversion.py b/tests/unit/fast_agent/mcp/test_prompt_multipart_conversion.py index 3ada1d355..42c2c66b8 100644 --- a/tests/unit/fast_agent/mcp/test_prompt_multipart_conversion.py +++ b/tests/unit/fast_agent/mcp/test_prompt_multipart_conversion.py @@ -13,6 +13,11 @@ from fast_agent.mcp.prompts.prompt_template import PromptTemplateLoader +def _text(block: object) -> TextContent: + assert isinstance(block, TextContent) + return block + + def test_resource_message_role_merging(): """ Test that demonstrates how resources cause role merging issues. @@ -153,7 +158,7 @@ def test_playback_pattern_with_simple_messages(): assert multipart[3].role == "assistant" # Check content is preserved - assert multipart[0].content[0].text == "user1" - assert multipart[1].content[0].text == "assistant1" - assert multipart[2].content[0].text == "user2" - assert multipart[3].content[0].text == "assistant2" + assert _text(multipart[0].content[0]).text == "user1" + assert _text(multipart[1].content[0]).text == "assistant1" + assert _text(multipart[2].content[0]).text == "user2" + assert _text(multipart[3].content[0]).text == "assistant2" diff --git a/tests/unit/fast_agent/mcp/test_prompt_serialization.py b/tests/unit/fast_agent/mcp/test_prompt_serialization.py index 2a4b50652..4284de4b8 100644 --- a/tests/unit/fast_agent/mcp/test_prompt_serialization.py +++ b/tests/unit/fast_agent/mcp/test_prompt_serialization.py @@ -3,6 +3,7 @@ """ from mcp.types import EmbeddedResource, ImageContent, TextContent, TextResourceContents +from pydantic import AnyUrl from fast_agent.mcp.prompt_message_extended import PromptMessageExtended from fast_agent.mcp.prompt_serialization import ( @@ -26,7 +27,7 @@ def test_json_serialization_and_deserialization(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://data.json", + uri=AnyUrl("resource://data.json"), mimeType="application/json", text='{"key": "value"}', ), @@ -63,20 +64,30 @@ def test_json_serialization_and_deserialization(self): # Check first message assert len(parsed_messages[0].content) == 2 - assert parsed_messages[0].content[0].type == "text" - assert parsed_messages[0].content[0].text == "Here's a resource:" - assert parsed_messages[0].content[1].type == "resource" - assert str(parsed_messages[0].content[1].resource.uri) == "resource://data.json" - assert parsed_messages[0].content[1].resource.mimeType == "application/json" - assert parsed_messages[0].content[1].resource.text == '{"key": "value"}' + first_block = parsed_messages[0].content[0] + assert isinstance(first_block, TextContent) + assert first_block.type == "text" + assert first_block.text == "Here's a resource:" + resource_block = parsed_messages[0].content[1] + assert isinstance(resource_block, EmbeddedResource) + assert resource_block.type == "resource" + resource = resource_block.resource + assert isinstance(resource, TextResourceContents) + assert str(resource.uri) == "resource://data.json" + assert resource.mimeType == "application/json" + assert resource.text == '{"key": "value"}' # Check second message assert len(parsed_messages[1].content) == 2 - assert parsed_messages[1].content[0].type == "text" - assert parsed_messages[1].content[0].text == "I've processed your resource." - assert parsed_messages[1].content[1].type == "image" - assert parsed_messages[1].content[1].data == "base64EncodedImage" - assert parsed_messages[1].content[1].mimeType == "image/jpeg" + assistant_block = parsed_messages[1].content[0] + assert isinstance(assistant_block, TextContent) + assert assistant_block.type == "text" + assert assistant_block.text == "I've processed your resource." + image_block = parsed_messages[1].content[1] + assert isinstance(image_block, ImageContent) + assert image_block.type == "image" + assert image_block.data == "base64EncodedImage" + assert image_block.mimeType == "image/jpeg" def test_multipart_to_delimited_format(self): """Test converting PromptMessageExtended to delimited format for saving.""" @@ -122,7 +133,7 @@ def test_multipart_with_resources_to_delimited_format(self): EmbeddedResource( type="resource", resource=TextResourceContents( - uri="resource://example.py", + uri=AnyUrl("resource://example.py"), mimeType="text/x-python", text="def hello():\n print('Hello, world!')", ), diff --git a/tests/unit/fast_agent/mcp/test_transport_factory_validation.py b/tests/unit/fast_agent/mcp/test_transport_factory_validation.py index 07720dac5..8c53b9a94 100644 --- a/tests/unit/fast_agent/mcp/test_transport_factory_validation.py +++ b/tests/unit/fast_agent/mcp/test_transport_factory_validation.py @@ -23,6 +23,7 @@ def test_transport_factory_validation_stdio_without_command(): # We need to access the internal transport_context_factory # This is a bit of a hack but necessary for unit testing config = registry.get_server_config("test_server") + assert config is not None # Simulate what happens inside launch_server def transport_context_factory(): diff --git a/tests/unit/fast_agent/mcp/test_ui_mixin.py b/tests/unit/fast_agent/mcp/test_ui_mixin.py index 7345aad0e..9c5b5a453 100644 --- a/tests/unit/fast_agent/mcp/test_ui_mixin.py +++ b/tests/unit/fast_agent/mcp/test_ui_mixin.py @@ -3,6 +3,7 @@ import pytest from mcp.types import CallToolResult, EmbeddedResource, TextContent, TextResourceContents +from pydantic import AnyUrl from rich.text import Text from fast_agent.agents.agent_types import AgentConfig @@ -85,10 +86,11 @@ def ui_agent(mock_config, mock_context): return agent -def create_ui_resource(uri="ui://test/component", text="Test UI"): +def create_ui_resource(uri: str = "ui://test/component", text: str = "Test UI"): """Helper to create a UI embedded resource.""" return EmbeddedResource( - type="resource", resource=TextResourceContents(uri=uri, mimeType="text/html", text=text) + type="resource", + resource=TextResourceContents(uri=AnyUrl(uri), mimeType="text/html", text=text), ) @@ -238,8 +240,8 @@ def stub_ui_links(resources): def stub_open_browser(links, **kwargs): pass - ui_mixin_module.ui_links_from_channel = stub_ui_links - ui_mixin_module.open_links_in_browser = stub_open_browser + setattr(ui_mixin_module, "ui_links_from_channel", stub_ui_links) + setattr(ui_mixin_module, "open_links_in_browser", stub_open_browser) await ui_agent.show_assistant_message(assistant_msg) @@ -249,8 +251,8 @@ def stub_open_browser(links, **kwargs): finally: # Restore original functions - ui_mixin_module.ui_links_from_channel = original_ui_links_from_channel - ui_mixin_module.open_links_in_browser = original_open_links_in_browser + setattr(ui_mixin_module, "ui_links_from_channel", original_ui_links_from_channel) + setattr(ui_mixin_module, "open_links_in_browser", original_open_links_in_browser) def test_is_ui_embedded_resource(ui_agent): @@ -263,7 +265,7 @@ def test_is_ui_embedded_resource(ui_agent): non_ui = EmbeddedResource( type="resource", resource=TextResourceContents( - uri="http://example.com", mimeType="text/html", text="content" + uri=AnyUrl("http://example.com"), mimeType="text/html", text="content" ), ) assert ui_agent._is_ui_embedded_resource(non_ui) is False diff --git a/tests/unit/fast_agent/tools/test_shell_runtime.py b/tests/unit/fast_agent/tools/test_shell_runtime.py index 35e0224d3..9e6f969c1 100644 --- a/tests/unit/fast_agent/tools/test_shell_runtime.py +++ b/tests/unit/fast_agent/tools/test_shell_runtime.py @@ -8,6 +8,7 @@ from typing import Any import pytest +from mcp.types import TextContent from fast_agent.tools.shell_runtime import ShellRuntime from fast_agent.ui import console @@ -102,6 +103,9 @@ async def test_execute_simple_command() -> None: result = await runtime.execute({"command": "echo hello"}) assert result.isError is False + assert result.content is not None + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) assert "hello" in result.content[0].text assert "exit code" in result.content[0].text @@ -121,6 +125,9 @@ async def test_execute_command_with_exit_code() -> None: result = await runtime.execute({"command": "false"}) assert result.isError is True + assert result.content is not None + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) assert "exit code" in result.content[0].text @@ -140,8 +147,13 @@ async def fast_sleep(_): result = await runtime.execute({"command": "Start-Sleep -Seconds 5"}) - assert signal.CTRL_BREAK_EVENT in process.sent_signals + ctrl_break = getattr(signal, "CTRL_BREAK_EVENT", None) + assert ctrl_break is not None + assert ctrl_break in process.sent_signals assert process.terminated is True assert captured["exec_args"][0].endswith("pwsh.exe") assert result.isError is True + assert result.content is not None + assert result.content[0].type == "text" + assert isinstance(result.content[0], TextContent) assert "(timeout after 0s" in result.content[0].text diff --git a/tests/unit/fast_agent/types/test_message_search.py b/tests/unit/fast_agent/types/test_message_search.py index 5ec468477..8977eb66a 100644 --- a/tests/unit/fast_agent/types/test_message_search.py +++ b/tests/unit/fast_agent/types/test_message_search.py @@ -33,7 +33,9 @@ def test_search_messages_user_scope(): results = search_messages(messages, "error", scope="user") assert len(results) == 1 assert results[0].role == "user" - assert "errors" in results[0].content[0].text + first_block = results[0].content[0] + assert isinstance(first_block, TextContent) + assert "errors" in first_block.text def test_search_messages_assistant_scope(): @@ -79,7 +81,11 @@ def test_search_messages_tool_results_scope(): results = search_messages(messages, r"Job started:", scope="tool_results") assert len(results) == 1 - assert results[0].tool_results["call_1"].content[0].text == "Job started: abc123def" + tool_results = results[0].tool_results + assert tool_results is not None + first_block = tool_results["call_1"].content[0] + assert isinstance(first_block, TextContent) + assert first_block.text == "Job started: abc123def" def test_search_messages_tool_calls_scope(): @@ -110,7 +116,9 @@ def test_search_messages_tool_calls_scope(): # Search for tool name results = search_messages(messages, "create_job", scope="tool_calls") assert len(results) == 1 - assert "call_1" in results[0].tool_calls + tool_calls = results[0].tool_calls + assert tool_calls is not None + assert "call_1" in tool_calls # Search in arguments results = search_messages(messages, "processing", scope="tool_calls") diff --git a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py index f80cd00a3..83f96b1f9 100644 --- a/tests/unit/fast_agent/ui/test_streaming_mode_switch.py +++ b/tests/unit/fast_agent/ui/test_streaming_mode_switch.py @@ -1,3 +1,5 @@ +from typing import Literal + from fast_agent.config import Settings from fast_agent.llm.stream_types import StreamChunk from fast_agent.ui import console @@ -25,7 +27,9 @@ def _restore_console_size(original_width: object | None, original_height: object console.console._height = original_height -def _make_handle(streaming_mode: str = "markdown") -> _StreamingMessageHandle: +def _make_handle( + streaming_mode: Literal["markdown", "plain", "none"] = "markdown", +) -> _StreamingMessageHandle: settings = Settings() settings.logger.streaming = streaming_mode display = ConsoleDisplay(settings) diff --git a/tests/unit/hf_inference_acp/test_apply_model.py b/tests/unit/hf_inference_acp/test_apply_model.py index 794d950f9..d3bb229ba 100644 --- a/tests/unit/hf_inference_acp/test_apply_model.py +++ b/tests/unit/hf_inference_acp/test_apply_model.py @@ -20,8 +20,8 @@ async def test_apply_model_does_not_override_request_params_model(monkeypatch) - pytest.importorskip("ruamel.yaml") _ensure_hf_inference_acp_on_path() - import hf_inference_acp.agents as agents_mod - from hf_inference_acp.agents import HuggingFaceAgent + import hf_inference_acp.agents as agents_mod # ty: ignore[unresolved-import] + from hf_inference_acp.agents import HuggingFaceAgent # ty: ignore[unresolved-import] calls: list[dict] = [] diff --git a/tests/unit/hf_inference_acp/test_wizard_curated_models.py b/tests/unit/hf_inference_acp/test_wizard_curated_models.py index 76c326928..014ac0f66 100644 --- a/tests/unit/hf_inference_acp/test_wizard_curated_models.py +++ b/tests/unit/hf_inference_acp/test_wizard_curated_models.py @@ -17,9 +17,11 @@ async def test_wizard_model_selection_uses_curated_ids() -> None: pytest.importorskip("ruamel.yaml") _ensure_hf_inference_acp_on_path() - from hf_inference_acp.wizard.model_catalog import CURATED_MODELS - from hf_inference_acp.wizard.stages import WizardStage - from hf_inference_acp.wizard.wizard_llm import WizardSetupLLM + from hf_inference_acp.wizard.model_catalog import ( # ty: ignore[unresolved-import] + CURATED_MODELS, + ) + from hf_inference_acp.wizard.stages import WizardStage # ty: ignore[unresolved-import] + from hf_inference_acp.wizard.wizard_llm import WizardSetupLLM # ty: ignore[unresolved-import] llm = WizardSetupLLM() llm._state.first_message = False # skip welcome @@ -30,4 +32,3 @@ async def test_wizard_model_selection_uses_curated_ids() -> None: assert llm._state.selected_model == CURATED_MODELS[0].id assert llm._state.stage == WizardStage.MCP_CONNECT assert "Step 3" in response - diff --git a/typesafe.md b/typesafe.md index 2d1750602..c3525a04d 100644 --- a/typesafe.md +++ b/typesafe.md @@ -95,6 +95,10 @@ Python 3.13+ and current typing guidance. - Use `Self` for fluent APIs and `TypeAlias` for complex, reused types. - Use `type: ignore` only when interacting with third-party APIs that are untyped or known-broken; otherwise prefer `ty: ignore[rule]` with the specific rule. +- For tests, prefer `assert isinstance(x, ConcreteType)` (or `assert x is not None`) to narrow types + instead of `cast(...)` when the runtime behavior enforces the type. +- When dealing with untyped dict-like payloads in tests, add `assert isinstance(payload, dict)` + before passing into helpers instead of casting to `dict[str, Any]`. ## Reproducible Plan