2 changes: 1 addition & 1 deletion src/fast_agent/agents/workflow/agents_as_tools_agent.py
@@ -243,7 +243,7 @@ class AgentsAsToolsOptions:

history_mode: HistoryMode = HistoryMode.FORK
max_parallel: int | None = None
child_timeout_sec: int | None = None
child_timeout_sec: float | None = None
max_display_instances: int = 20

def __post_init__(self) -> None:
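Note: the only change in this file widens child_timeout_sec from int | None to float | None, so fractional per-child timeouts now type-check. A minimal usage sketch, assuming AgentsAsToolsOptions is a plain dataclass constructed with keyword arguments and that __post_init__ adds no further restriction on these values:

from fast_agent.agents.workflow.agents_as_tools_agent import AgentsAsToolsOptions

# Fractional timeouts now satisfy the annotation; unset fields keep their declared defaults.
options = AgentsAsToolsOptions(
    child_timeout_sec=2.5,  # previously a type error while annotated as int | None
    max_parallel=4,         # illustrative value, not required
)
print(options.child_timeout_sec)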
2 changes: 2 additions & 0 deletions src/fast_agent/interfaces.py
@@ -92,6 +92,8 @@ def get_request_params(
request_params: RequestParams | None = None,
) -> RequestParams: ...

default_request_params: RequestParams

def add_stream_listener(self, listener: Callable[[StreamChunk], None]) -> Callable[[], None]: ...

def add_tool_stream_listener(
3 changes: 3 additions & 0 deletions src/fast_agent/mcp/mcp_content.py
@@ -161,6 +161,7 @@ def MCPPrompt(
Path,
bytes,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
@@ -262,6 +263,7 @@ def User(
Path,
bytes,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
@@ -278,6 +280,7 @@ def Assistant(
Path,
bytes,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
32 changes: 28 additions & 4 deletions src/fast_agent/mcp/prompt.py
@@ -10,7 +10,7 @@
from typing import Literal, Union

from mcp import CallToolRequest
from mcp.types import ContentBlock, PromptMessage
from mcp.types import ContentBlock, PromptMessage, ReadResourceResult, ResourceContents

from fast_agent.mcp.mcp_content import Assistant, MCPPrompt, User
from fast_agent.types import LlmStopReason, PromptMessageExtended
@@ -39,7 +39,15 @@ class Prompt:
def user(
cls,
*content_items: Union[
str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended
str,
Path,
bytes,
dict,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
],
) -> PromptMessageExtended:
"""
@@ -62,7 +70,15 @@ def user(
def assistant(
cls,
*content_items: Union[
str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended
str,
Path,
bytes,
dict,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
],
stop_reason: LlmStopReason | None = None,
tool_calls: dict[str, CallToolRequest] | None = None,
@@ -102,7 +118,15 @@ def assistant(
def message(
cls,
*content_items: Union[
str, Path, bytes, dict, ContentBlock, PromptMessage, PromptMessageExtended
str,
Path,
bytes,
dict,
ContentBlock,
ResourceContents,
ReadResourceResult,
PromptMessage,
PromptMessageExtended,
],
role: Literal["user", "assistant"] = "user",
) -> PromptMessageExtended:
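Note: Prompt.user, Prompt.assistant and Prompt.message (like the MCPPrompt / User / Assistant helpers above) now list ResourceContents and ReadResourceResult among their accepted content items, so resource payloads no longer need to be pre-wrapped in a ContentBlock. A minimal sketch, assuming TextResourceContents from mcp.types as the concrete ResourceContents; the URI and text are illustrative:

from mcp.types import TextResourceContents

from fast_agent.mcp.prompt import Prompt

notes = TextResourceContents(
    uri="file:///tmp/notes.txt",  # made-up resource
    mimeType="text/plain",
    text="Q3 planning notes",
)

# Returns a PromptMessageExtended with role "user"; the resource rides along as a content item.
message = Prompt.user("Summarise the attached notes:", notes)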
1 change: 1 addition & 0 deletions tests/e2e/llm/test_llm_e2e_reasoning.py
@@ -46,6 +46,7 @@ def on_chunk(chunk: StreamChunk) -> None:

async def _run_turn(agent: LlmAgent, prompt: str) -> tuple[dict[str, int], list[str], str | None]:
listener, state = _make_stream_tracker()
assert agent.llm is not None
remove = agent.llm.add_stream_listener(listener)
try:
result = await agent.generate(prompt)
2 changes: 1 addition & 1 deletion tests/e2e/smoke/tensorzero/test_agent_interaction.py
@@ -48,7 +48,7 @@ async def dummy_agent_func():

async with fast.run() as agent_app:
agent_instance = agent_app.default
print(f"\nSending {len(messages_to_send)} messages to agent '{agent_instance._name}'...")
print(f"\nSending {len(messages_to_send)} messages to agent '{agent_instance.name}'...")
for i, msg_text in enumerate(messages_to_send):
print(f"Sending message {i + 1}: '{msg_text}'")
await agent_instance.send(msg_text)
4 changes: 2 additions & 2 deletions tests/e2e/smoke/tensorzero/test_simple_agent_interaction.py
@@ -35,8 +35,8 @@ async def dummy_simple_agent_func():
async with fast.run() as agent_app:
agent_instance = agent_app.simple_default

print(f"\nSending message to agent '{agent_instance._name}': '{message_to_send}'")
print(f"\nSending message to agent '{agent_instance.name}': '{message_to_send}'")
await agent_instance.send(message_to_send)
print(f"Message sent successfully to '{agent_instance._name}'.")
print(f"Message sent successfully to '{agent_instance.name}'.")

print("\nSimple agent interaction smoke test completed successfully.")
3 changes: 2 additions & 1 deletion tests/integration/acp/test_acp_skills_manager.py
@@ -243,7 +243,7 @@ async def test_skills_registry_numbered_selection(tmp_path: Path) -> None:
assert "Registry set to" in response
assert get_settings().skills.marketplace_url == marketplace2.as_posix()
# Configured list is preserved
assert len(get_settings().skills.marketplace_urls) == 2
marketplace_urls = get_settings().skills.marketplace_urls or []
assert len(marketplace_urls) == 2

# Invalid number
response = await handler.execute_command("skills", "registry 99")
12 changes: 8 additions & 4 deletions tests/integration/acp/test_acp_status_line.py
@@ -4,6 +4,7 @@
import re
import sys
from pathlib import Path
from typing import cast

import pytest
from acp.helpers import text_block
@@ -37,14 +38,17 @@
def _extract_status_line(meta: object) -> str | None:
if not isinstance(meta, dict):
return None
field_meta = meta.get("field_meta")
meta_dict = cast("dict[str, object]", meta)
field_meta = meta_dict.get("field_meta")
if isinstance(field_meta, dict):
metrics = field_meta.get("openhands.dev/metrics")
field_meta_dict = cast("dict[str, object]", field_meta)
metrics = field_meta_dict.get("openhands.dev/metrics")
else:
metrics = meta.get("openhands.dev/metrics")
metrics = meta_dict.get("openhands.dev/metrics")
if not isinstance(metrics, dict):
return None
status_line = metrics.get("status_line")
metrics_dict = cast("dict[str, object]", metrics)
status_line = metrics_dict.get("status_line")
if isinstance(status_line, str) and status_line.strip():
return status_line
return None
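Note: the status-line helper change is about narrowing for strict type checking — isinstance(meta, dict) only narrows object to a dict with unknown key and value types, so .get() results stay unknown; the added casts declare dict[str, object] before each lookup. A self-contained sketch of the same pattern (key names mirror the test helper, the field_meta fallback is omitted, and the example value is made up):

from typing import cast


def read_status_line(meta: object) -> str | None:
    if not isinstance(meta, dict):
        return None
    meta_dict = cast("dict[str, object]", meta)  # declare key/value types for the checker
    metrics = meta_dict.get("openhands.dev/metrics")
    if not isinstance(metrics, dict):
        return None
    status = cast("dict[str, object]", metrics).get("status_line")
    return status if isinstance(status, str) and status.strip() else None


print(read_status_line({"openhands.dev/metrics": {"status_line": "1.2k tokens"}}))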
4 changes: 2 additions & 2 deletions tests/integration/acp/test_acp_terminal.py
@@ -130,7 +130,7 @@ async def test_acp_terminal_execution() -> None:

# Manually test terminal lifecycle (client creates ID)
create_result = await client.create_terminal(command="echo test", session_id=session_id)
terminal_id = create_result.terminalId
terminal_id = create_result.terminal_id

# Verify terminal was created with client-generated ID
assert terminal_id == "terminal-1" # First terminal
@@ -141,7 +141,7 @@ async def test_acp_terminal_execution() -> None:
output = await client.terminal_output(session_id=session_id, terminal_id=terminal_id)
assert "Executed: echo test" in output.output
exit_info = await client.wait_for_terminal_exit(session_id=session_id, terminal_id=terminal_id)
assert exit_info.exitCode == 0
assert exit_info.exit_code == 0

# Release terminal
await client.release_terminal(session_id=session_id, terminal_id=terminal_id)
18 changes: 9 additions & 9 deletions tests/integration/acp/test_acp_terminal_lifecycle.py
@@ -34,15 +34,15 @@ async def test_terminal_create_lifecycle() -> None:

# Create first terminal
result1 = await client.create_terminal(command="echo hello", session_id="test-session")
terminal_id1 = result1.terminalId
terminal_id1 = result1.terminal_id

assert terminal_id1 == "terminal-1"
assert len(client.terminals) == 1
assert client.terminals[terminal_id1]["command"] == "echo hello"

# Create second terminal
result2 = await client.create_terminal(command="pwd", session_id="test-session")
terminal_id2 = result2.terminalId
terminal_id2 = result2.terminal_id

assert terminal_id2 == "terminal-2"
assert len(client.terminals) == 2
@@ -68,7 +68,7 @@ async def test_terminal_output_retrieval() -> None:

# Create terminal
result = await client.create_terminal(command="echo test output", session_id="test-session")
terminal_id = result.terminalId
terminal_id = result.terminal_id

# Get output
output = await client.terminal_output(session_id="test-session", terminal_id=terminal_id)
@@ -88,14 +88,14 @@ async def test_terminal_wait_for_exit() -> None:

# Create terminal
result = await client.create_terminal(command="echo test", session_id="test-session")
terminal_id = result.terminalId
terminal_id = result.terminal_id

# Wait for exit (immediate in test client)
exit_result = await client.wait_for_terminal_exit(
session_id="test-session", terminal_id=terminal_id
)

assert exit_result.exitCode == 0
assert exit_result.exit_code == 0
assert exit_result.signal is None

# Cleanup
@@ -110,7 +110,7 @@ async def test_terminal_kill() -> None:

# Create terminal
result = await client.create_terminal(command="sleep 100", session_id="test-session")
terminal_id = result.terminalId
terminal_id = result.terminal_id

# Kill it
await client.kill_terminal(session_id="test-session", terminal_id=terminal_id)
@@ -123,7 +123,7 @@ async def test_terminal_kill() -> None:
exit_result = await client.wait_for_terminal_exit(
session_id="test-session", terminal_id=terminal_id
)
assert exit_result.exitCode is None
assert exit_result.exit_code is None
assert exit_result.signal == "SIGKILL"

# Cleanup
@@ -140,7 +140,7 @@ async def test_terminal_release_cleanup() -> None:
terminals = []
for i in range(3):
result = await client.create_terminal(command=f"echo {i}", session_id="test-session")
terminals.append(result.terminalId)
terminals.append(result.terminal_id)

assert len(client.terminals) == 3

@@ -171,7 +171,7 @@ async def test_terminal_missing_id() -> None:
exit_result = await client.wait_for_terminal_exit(
session_id="test-session", terminal_id="missing"
)
assert exit_result.exitCode is None
assert exit_result.exit_code is None

# Kill non-existent terminal (should not error)
await client.kill_terminal(session_id="test-session", terminal_id="missing")
14 changes: 8 additions & 6 deletions tests/integration/acp/test_client.py
@@ -52,7 +52,7 @@ def queue_permission_cancelled(self) -> None:
def queue_permission_selected(self, option_id: str) -> None:
self.permission_outcomes.append(
RequestPermissionResponse(
outcome=AllowedOutcome(optionId=option_id, outcome="selected")
outcome=AllowedOutcome(option_id=option_id, outcome="selected")
)
)

@@ -149,7 +149,7 @@ async def create_terminal(
}

# Return the ID we created
return CreateTerminalResponse(terminalId=terminal_id)
return CreateTerminalResponse(terminal_id=terminal_id)

async def terminal_output(
self,
@@ -161,9 +161,9 @@ async def terminal_output(
terminal = self.terminals.get(terminal_id, {})
exit_code = terminal.get("exit_code")
if isinstance(exit_code, int) and exit_code >= 0:
exit_status = TerminalExitStatus(exitCode=exit_code)
exit_status = TerminalExitStatus(exit_code=exit_code)
elif isinstance(exit_code, int) and exit_code < 0:
exit_status = TerminalExitStatus(exitCode=None, signal="SIGKILL")
exit_status = TerminalExitStatus(exit_code=None, signal="SIGKILL")
else:
exit_status = None

@@ -194,10 +194,12 @@ async def wait_for_terminal_exit(
terminal = self.terminals.get(terminal_id, {})
exit_code = terminal.get("exit_code")
if isinstance(exit_code, int) and exit_code >= 0:
return WaitForTerminalExitResponse(exitCode=exit_code, signal=None)
return WaitForTerminalExitResponse(exit_code=exit_code, signal=None)

# Unknown or negative exit -> model as killed/terminated with no exit code
return WaitForTerminalExitResponse(exitCode=None, signal="SIGKILL" if exit_code else None)
return WaitForTerminalExitResponse(
exit_code=None, signal="SIGKILL" if exit_code else None
)

async def kill_terminal(
self,
4 changes: 3 additions & 1 deletion tests/integration/acp/test_set_model_validation.py
@@ -183,8 +183,10 @@ async def test_validate_resolves_aliases_to_hf_models() -> None:
resolved_model = model
break

if hf_alias is None:
if hf_alias is None or resolved_model is None:
pytest.skip("No HF model aliases found in MODEL_ALIASES")
assert hf_alias is not None
assert resolved_model is not None

# Extract the expected model ID from the resolved model
expected_model_id = resolved_model[3:] # Strip "hf."
6 changes: 3 additions & 3 deletions tests/integration/api/mcp_tools_server.py
@@ -43,9 +43,9 @@ def shirt_colour() -> str:
@app.tool(name="implementation", description="Returns the Client implementation")
def implementation(ctx: Context) -> str:
assert ctx.session.client_params is not None, "Client params should not be None"
clientInfo = ctx.session.client_params.clientInfo or None

return clientInfo.model_dump_json()
client_info = ctx.session.client_params.clientInfo
assert client_info is not None
return client_info.model_dump_json()


if __name__ == "__main__":
1 change: 1 addition & 0 deletions tests/integration/api/test_logger_textio.py
@@ -75,6 +75,7 @@ def test_logger_textio_real_process(test_script_path, logger_io):
)

# Read and process stderr lines
assert process.stderr is not None
for line in process.stderr:
logger_io.write(line)

7 changes: 4 additions & 3 deletions tests/integration/api/test_retry_error_channel.py
@@ -15,8 +15,8 @@
class FailingOpenAILLM(OpenAILLM):
"""Test double that always raises an APIError."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, provider=Provider.OPENAI, **kwargs)
def __init__(self, **kwargs) -> None:
super().__init__(provider=Provider.OPENAI, **kwargs)
self.attempts = 0

async def _apply_prompt_provider_specific(
@@ -40,7 +40,8 @@ async def test_retry_exhaustion_returns_error_channel():

assert llm.attempts == 1 # no retries when FAST_AGENT_RETRIES=0
assert response.stop_reason == LlmStopReason.ERROR
assert FAST_AGENT_ERROR_CHANNEL in (response.channels or {})
assert response.channels is not None
assert FAST_AGENT_ERROR_CHANNEL in response.channels
error_block = response.channels[FAST_AGENT_ERROR_CHANNEL][0]
assert "request failed" in (get_text(error_block) or "")

5 changes: 3 additions & 2 deletions tests/integration/elicitation/test_elicitation_handler.py
@@ -25,9 +25,10 @@ async def custom_elicitation_handler(
"""Test handler that returns predictable responses for integration testing."""
logger.info(f"Test elicitation handler called with: {params.message}")

if params.requestedSchema:
requested_schema = getattr(params, "requestedSchema", None)
if requested_schema:
# Generate test data based on the schema for round-trip verification
properties = params.requestedSchema.get("properties", {})
properties = requested_schema.get("properties", {})
content: dict[str, Any] = {}

# Provide test values for each field
5 changes: 3 additions & 2 deletions tests/integration/elicitation/test_elicitation_integration.py
@@ -28,9 +28,10 @@ async def custom_test_elicitation_handler(
"""Test handler that returns predictable responses for integration testing."""
logger.info(f"Test elicitation handler called with: {params.message}")

if params.requestedSchema:
requested_schema = getattr(params, "requestedSchema", None)
if requested_schema:
# Generate test data based on the schema for round-trip verification
properties = params.requestedSchema.get("properties", {})
properties = requested_schema.get("properties", {})
content: dict[str, Any] = {}

# Provide test values for each field