diff --git a/.gitignore b/.gitignore index 988cc3f2a..7b0332a11 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -# Byte-compiled / optimized / DLL files +s# Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class @@ -200,3 +200,4 @@ fastagent.jsonl # JetBrains IDEs .idea/ tests/e2e/smoke/base/weather_location.txt +tests/integration/roots/fastagent.jsonl diff --git a/.vscode/fastagent.config.schema.json b/.vscode/fastagent.config.schema.json index 6020d050d..45f2f3cbd 100644 --- a/.vscode/fastagent.config.schema.json +++ b/.vscode/fastagent.config.schema.json @@ -236,6 +236,12 @@ "title": "Truncate Tools", "type": "boolean", "description": "Truncate display of long tool calls" + }, + "enable_markup": { + "default": true, + "title": "Enable Markup", + "type": "boolean", + "description": "Enable markup in console output. Disable for outputs that may conflict with rich console formatting" } }, "title": "LoggerSettings", @@ -348,7 +354,8 @@ "default": "stdio", "enum": [ "stdio", - "sse" + "sse", + "http" ], "title": "Transport", "type": "string", @@ -822,7 +829,8 @@ "http_timeout": 5.0, "show_chat": true, "show_tools": true, - "truncate_tools": true + "truncate_tools": true, + "enable_markup": true }, "description": "Logger settings for the fast-agent application" } diff --git a/src/mcp_agent/cli/commands/README.md b/src/mcp_agent/cli/commands/README.md new file mode 100644 index 000000000..60d9449fd --- /dev/null +++ b/src/mcp_agent/cli/commands/README.md @@ -0,0 +1,68 @@ +# Fast-Agent CLI Commands + +This directory contains the command implementations for the fast-agent CLI. + +## Go Command + +The `go` command allows you to run an interactive agent directly from the command line without +creating a dedicated agent.py file. + +### Usage + +```bash +fast-agent go [OPTIONS] +``` + +### Options + +- `--name TEXT`: Name for the agent (default: "FastAgent CLI") +- `--instruction`, `-i TEXT`: Instruction for the agent (default: "You are a helpful AI Agent.") +- `--config-path`, `-c TEXT`: Path to config file +- `--servers TEXT`: Comma-separated list of server names to enable from config +- `--url TEXT`: Comma-separated list of HTTP/SSE URLs to connect to directly +- `--auth TEXT`: Bearer token for authorization with URL-based servers +- `--model TEXT`: Override the default model (e.g., haiku, sonnet, gpt-4) +- `--message`, `-m TEXT`: Message to send to the agent (skips interactive mode) +- `--prompt-file`, `-p TEXT`: Path to a prompt file to use (either text or JSON) +- `--quiet`: Disable progress display and logging + +### Examples + +```bash +# Basic usage with interactive mode +fast-agent go --model=haiku + +# Specifying servers from configuration +fast-agent go --servers=fetch,filesystem --model=haiku + +# Directly connecting to HTTP/SSE servers via URLs +fast-agent go --url=http://localhost:8001/mcp,http://api.example.com/sse + +# Connecting to an authenticated API endpoint +fast-agent go --url=https://api.example.com/mcp --auth=YOUR_API_TOKEN + +# Non-interactive mode with a single message +fast-agent go --message="What is the weather today?" --model=haiku + +# Using a prompt file +fast-agent go --prompt-file=my-prompt.txt --model=haiku +``` + +### URL Connection Details + +The `--url` parameter allows you to connect directly to HTTP or SSE servers using URLs. + +- URLs must have http or https scheme +- The transport type is determined by the URL path: + - URLs ending with `/sse` are treated as SSE transport + - URLs ending with `/mcp` or automatically appended with `/mcp` are treated as HTTP transport +- Server names are generated automatically based on the hostname, port, and path +- The URL-based servers are added to the agent's configuration and enabled + +### Authentication + +The `--auth` parameter provides authentication for URL-based servers: + +- When provided, it creates an `Authorization: Bearer TOKEN` header for all URL-based servers +- This is commonly used with API endpoints that require authentication +- Example: `fast-agent go --url=https://api.example.com/mcp --auth=12345abcde` \ No newline at end of file diff --git a/src/mcp_agent/cli/commands/go.py b/src/mcp_agent/cli/commands/go.py index 16df8abf1..e06ab76db 100644 --- a/src/mcp_agent/cli/commands/go.py +++ b/src/mcp_agent/cli/commands/go.py @@ -2,16 +2,18 @@ import asyncio import sys -from typing import List, Optional +from typing import Dict, List, Optional import typer +from mcp_agent.cli.commands.url_parser import generate_server_configs, parse_server_urls from mcp_agent.core.fastagent import FastAgent app = typer.Typer( help="Run an interactive agent directly from the command line without creating an agent.py file" ) + async def _run_agent( name: str = "FastAgent CLI", instruction: str = "You are a helpful AI Agent.", @@ -19,33 +21,61 @@ async def _run_agent( server_list: Optional[List[str]] = None, model: Optional[str] = None, message: Optional[str] = None, - prompt_file: Optional[str] = None + prompt_file: Optional[str] = None, + url_servers: Optional[Dict[str, Dict[str, str]]] = None, ) -> None: """Async implementation to run an interactive agent.""" from pathlib import Path + from mcp_agent.config import MCPServerSettings, MCPSettings from mcp_agent.mcp.prompts.prompt_load import load_prompt_multipart - # Create the FastAgent instance with CLI arg parsing enabled - # It will automatically parse args like --model, --quiet, etc. + # Create the FastAgent instance fast_kwargs = { "name": name, "config_path": config_path, "ignore_unknown_args": True, "parse_cli_args": False, # Don't parse CLI args, we're handling it ourselves } - + fast = FastAgent(**fast_kwargs) + # Add URL-based servers to the context configuration + if url_servers: + # Initialize the app to ensure context is ready + await fast.app.initialize() + + # Initialize mcp settings if needed + if not hasattr(fast.app.context.config, "mcp"): + fast.app.context.config.mcp = MCPSettings() + + # Initialize servers dictionary if needed + if ( + not hasattr(fast.app.context.config.mcp, "servers") + or fast.app.context.config.mcp.servers is None + ): + fast.app.context.config.mcp.servers = {} + + # Add each URL server to the config + for server_name, server_config in url_servers.items(): + server_settings = {"transport": server_config["transport"], "url": server_config["url"]} + + # Add headers if present in the server config + if "headers" in server_config: + server_settings["headers"] = server_config["headers"] + + fast.app.context.config.mcp.servers[server_name] = MCPServerSettings(**server_settings) + # Define the agent with specified parameters agent_kwargs = {"instruction": instruction} if server_list: agent_kwargs["servers"] = server_list if model: agent_kwargs["model"] = model - + # Handle prompt file and message options if message or prompt_file: + @fast.agent(**agent_kwargs) async def cli_agent(): async with fast.run() as agent: @@ -55,7 +85,7 @@ async def cli_agent(): print(response) elif prompt_file: prompt = load_prompt_multipart(Path(prompt_file)) - response = await agent.generate(prompt) + response = await agent.default.generate(prompt) # Print the response text and exit print(response.last_text()) else: @@ -68,18 +98,37 @@ async def cli_agent(): # Run the agent await cli_agent() + def run_async_agent( - name: str, - instruction: str, - config_path: Optional[str] = None, + name: str, + instruction: str, + config_path: Optional[str] = None, servers: Optional[str] = None, + urls: Optional[str] = None, + auth: Optional[str] = None, model: Optional[str] = None, message: Optional[str] = None, - prompt_file: Optional[str] = None + prompt_file: Optional[str] = None, ): """Run the async agent function with proper loop handling.""" - server_list = servers.split(',') if servers else None - + server_list = servers.split(",") if servers else None + + # Parse URLs and generate server configurations if provided + url_servers = None + if urls: + try: + parsed_urls = parse_server_urls(urls, auth) + url_servers = generate_server_configs(parsed_urls) + # If we have servers from URLs, add their names to the server_list + if url_servers and not server_list: + server_list = list(url_servers.keys()) + elif url_servers and server_list: + # Merge both lists + server_list.extend(list(url_servers.keys())) + except ValueError as e: + print(f"Error parsing URLs: {e}") + return + # Check if we're already in an event loop try: loop = asyncio.get_event_loop() @@ -92,24 +141,27 @@ def run_async_agent( # No event loop exists, so we'll create one loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + try: - loop.run_until_complete(_run_agent( - name=name, - instruction=instruction, - config_path=config_path, - server_list=server_list, - model=model, - message=message, - prompt_file=prompt_file - )) + loop.run_until_complete( + _run_agent( + name=name, + instruction=instruction, + config_path=config_path, + server_list=server_list, + model=model, + message=message, + prompt_file=prompt_file, + url_servers=url_servers, + ) + ) finally: try: # Clean up the loop tasks = asyncio.all_tasks(loop) for task in tasks: task.cancel() - + # Run the event loop until all tasks are done if sys.version_info >= (3, 7): loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) @@ -118,6 +170,7 @@ def run_async_agent( except Exception: pass + @app.callback(invoke_without_command=True) def go( ctx: typer.Context, @@ -131,6 +184,12 @@ def go( servers: Optional[str] = typer.Option( None, "--servers", help="Comma-separated list of server names to enable from config" ), + urls: Optional[str] = typer.Option( + None, "--url", help="Comma-separated list of HTTP/SSE URLs to connect to" + ), + auth: Optional[str] = typer.Option( + None, "--auth", help="Bearer token for authorization with URL-based servers" + ), model: Optional[str] = typer.Option( None, "--model", help="Override the default model (e.g., haiku, sonnet, gpt-4)" ), @@ -148,6 +207,8 @@ def go( fast-agent go --model=haiku --instruction="You are a coding assistant" --servers=fetch,filesystem fast-agent go --message="What is the weather today?" --model=haiku fast-agent go --prompt-file=my-prompt.txt --model=haiku + fast-agent go --url=http://localhost:8001/mcp,http://api.example.com/sse + fast-agent go --url=https://api.example.com/mcp --auth=YOUR_API_TOKEN This will start an interactive session with the agent, using the specified model and instruction. It will use the default configuration from fastagent.config.yaml @@ -157,15 +218,19 @@ def go( --model Override the default model (e.g., --model=haiku) --quiet Disable progress display and logging --servers Comma-separated list of server names to enable from config + --url Comma-separated list of HTTP/SSE URLs to connect to + --auth Bearer token for authorization with URL-based servers --message, -m Send a single message and exit --prompt-file, -p Use a prompt file instead of interactive mode """ run_async_agent( - name=name, - instruction=instruction, - config_path=config_path, + name=name, + instruction=instruction, + config_path=config_path, servers=servers, + urls=urls, + auth=auth, model=model, message=message, - prompt_file=prompt_file - ) \ No newline at end of file + prompt_file=prompt_file, + ) diff --git a/src/mcp_agent/cli/commands/url_parser.py b/src/mcp_agent/cli/commands/url_parser.py new file mode 100644 index 000000000..fc2fb525a --- /dev/null +++ b/src/mcp_agent/cli/commands/url_parser.py @@ -0,0 +1,185 @@ +""" +URL parsing utility for the fast-agent CLI. +Provides functions to parse URLs and determine MCP server configurations. +""" + +import hashlib +import re +from typing import Dict, List, Literal, Tuple +from urllib.parse import urlparse + + +def parse_server_url( + url: str, +) -> Tuple[str, Literal["http", "sse"], str]: + """ + Parse a server URL and determine the transport type and server name. + + Args: + url: The URL to parse + + Returns: + Tuple containing: + - server_name: A generated name for the server + - transport_type: Either "http" or "sse" based on URL + - url: The parsed and validated URL + + Raises: + ValueError: If the URL is invalid or unsupported + """ + # Basic URL validation + if not url: + raise ValueError("URL cannot be empty") + + # Parse the URL + parsed_url = urlparse(url) + + # Ensure scheme is present and is either http or https + if not parsed_url.scheme or parsed_url.scheme not in ("http", "https"): + raise ValueError(f"URL must have http or https scheme: {url}") + + # Ensure netloc (hostname) is present + if not parsed_url.netloc: + raise ValueError(f"URL must include a hostname: {url}") + + # Determine transport type based on URL path + transport_type: Literal["http", "sse"] = "http" + if parsed_url.path.endswith("/sse"): + transport_type = "sse" + elif not parsed_url.path.endswith("/mcp"): + # If path doesn't end with /mcp or /sse, append /mcp + url = url if url.endswith("/") else f"{url}/" + url = f"{url}mcp" + + # Generate a server name based on hostname and port + server_name = generate_server_name(url) + + return server_name, transport_type, url + + +def generate_server_name(url: str) -> str: + """ + Generate a unique and readable server name from a URL. + + Args: + url: The URL to generate a name for + + Returns: + A server name string + """ + parsed_url = urlparse(url) + + # Extract hostname and port + hostname = parsed_url.netloc.split(":")[0] + + # Clean the hostname for use in a server name + # Replace non-alphanumeric characters with underscores + clean_hostname = re.sub(r"[^a-zA-Z0-9]", "_", hostname) + + if len(clean_hostname) > 15: + clean_hostname = clean_hostname[:9] + clean_hostname[-5:] + + # If it's localhost or an IP, add a more unique identifier + if clean_hostname in ("localhost", "127_0_0_1") or re.match(r"^(\d+_){3}\d+$", clean_hostname): + # Use the path as part of the name for uniqueness + path = parsed_url.path.strip("/") + path = re.sub(r"[^a-zA-Z0-9]", "_", path) + + # Include port if specified + port = "" + if ":" in parsed_url.netloc: + port = f"_{parsed_url.netloc.split(':')[1]}" + + if path: + return f"{clean_hostname}{port}_{path[:20]}" # Limit path length + else: + # Use a hash if no path for uniqueness + url_hash = hashlib.md5(url.encode()).hexdigest()[:8] + return f"{clean_hostname}{port}_{url_hash}" + + return clean_hostname + + +def parse_server_urls( + urls_param: str, auth_token: str = None +) -> List[Tuple[str, Literal["http", "sse"], str, Dict[str, str] | None]]: + """ + Parse a comma-separated list of URLs into server configurations. + + Args: + urls_param: Comma-separated list of URLs + auth_token: Optional bearer token for authorization + + Returns: + List of tuples containing (server_name, transport_type, url, headers) + + Raises: + ValueError: If any URL is invalid + """ + if not urls_param: + return [] + + # Split by comma and strip whitespace + url_list = [url.strip() for url in urls_param.split(",")] + + # Prepare headers if auth token is provided + headers = None + if auth_token: + headers = {"Authorization": f"Bearer {auth_token}"} + + # Parse each URL + result = [] + for url in url_list: + server_name, transport_type, parsed_url = parse_server_url(url) + result.append((server_name, transport_type, parsed_url, headers)) + + return result + + +def generate_server_configs( + parsed_urls: List[Tuple[str, Literal["http", "sse"], str, Dict[str, str] | None]], +) -> Dict[str, Dict[str, str | Dict[str, str]]]: + """ + Generate server configurations from parsed URLs. + + Args: + parsed_urls: List of tuples containing (server_name, transport_type, url, headers) + + Returns: + Dictionary of server configurations + """ + server_configs = {} + # Keep track of server name occurrences to handle collisions + name_counts = {} + + for server_name, transport_type, url, headers in parsed_urls: + # Handle name collisions by adding a suffix + final_name = server_name + if server_name in server_configs: + # Initialize counter if we haven't seen this name yet + if server_name not in name_counts: + name_counts[server_name] = 1 + + # Generate a new name with suffix + suffix = name_counts[server_name] + final_name = f"{server_name}_{suffix}" + name_counts[server_name] += 1 + + # Ensure the new name is also unique + while final_name in server_configs: + suffix = name_counts[server_name] + final_name = f"{server_name}_{suffix}" + name_counts[server_name] += 1 + + config = { + "transport": transport_type, + "url": url, + } + + # Add headers if provided + if headers: + config["headers"] = headers + + server_configs[final_name] = config + + return server_configs diff --git a/src/mcp_agent/config.py b/src/mcp_agent/config.py index 30881a088..9f93a038a 100644 --- a/src/mcp_agent/config.py +++ b/src/mcp_agent/config.py @@ -60,7 +60,7 @@ class MCPServerSettings(BaseModel): description: str | None = None """The description of the server.""" - transport: Literal["stdio", "sse"] = "stdio" + transport: Literal["stdio", "sse", "http"] = "stdio" """The transport mechanism.""" command: str | None = None diff --git a/src/mcp_agent/core/fastagent.py b/src/mcp_agent/core/fastagent.py index b83aced14..893bdc203 100644 --- a/src/mcp_agent/core/fastagent.py +++ b/src/mcp_agent/core/fastagent.py @@ -131,8 +131,8 @@ def __init__( ) parser.add_argument( "--transport", - choices=["sse", "stdio"], - default="sse", + choices=["sse", "http", "stdio"], + default="http", help="Transport protocol to use when running as a server (sse or stdio)", ) parser.add_argument( diff --git a/src/mcp_agent/core/request_params.py b/src/mcp_agent/core/request_params.py index ea84d476f..7b087829b 100644 --- a/src/mcp_agent/core/request_params.py +++ b/src/mcp_agent/core/request_params.py @@ -25,24 +25,23 @@ class RequestParams(CreateMessageRequestParams): model: str | None = None """ - The model to use for the LLM generation. + The model to use for the LLM generation. This can only be set during Agent creation. If specified, this overrides the 'modelPreferences' selection criteria. """ use_history: bool = True """ - Include the message history in the generate request. + Agent/LLM maintains conversation history. Does not include applied Prompts """ - max_iterations: int = 10 + max_iterations: int = 20 """ - The maximum number of iterations to run the LLM for. + The maximum number of tool calls allowed in a conversation turn """ parallel_tool_calls: bool = True """ - Whether to allow multiple tool calls per iteration. - Also known as multi-step tool use. + Whether to allow simultaneous tool calls """ response_format: Any | None = None """ diff --git a/src/mcp_agent/mcp/mcp_connection_manager.py b/src/mcp_agent/mcp/mcp_connection_manager.py index f7f02b6c1..f89ad9369 100644 --- a/src/mcp_agent/mcp/mcp_connection_manager.py +++ b/src/mcp_agent/mcp/mcp_connection_manager.py @@ -23,6 +23,7 @@ get_default_environment, stdio_client, ) +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp.types import JSONRPCMessage, ServerCapabilities from mcp_agent.config import MCPServerSettings @@ -40,6 +41,27 @@ logger = get_logger(__name__) +class StreamingContextAdapter: + """Adapter to provide a 3-value context from a 2-value context manager""" + + def __init__(self, context_manager): + self.context_manager = context_manager + self.cm_instance = None + + async def __aenter__(self): + self.cm_instance = await self.context_manager.__aenter__() + read_stream, write_stream = self.cm_instance + return read_stream, write_stream, None + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.context_manager.__aexit__(exc_type, exc_val, exc_tb) + + +def _add_none_to_context(context_manager): + """Helper to add a None value to context managers that return 2 values instead of 3""" + return StreamingContextAdapter(context_manager) + + class ServerConnection: """ Represents a long-lived MCP server connection, including: @@ -57,6 +79,7 @@ def __init__( tuple[ MemoryObjectReceiveStream[JSONRPCMessage | Exception], MemoryObjectSendStream[JSONRPCMessage], + GetSessionIdCallback | None, ], None, ], @@ -162,7 +185,7 @@ async def _server_lifecycle_task(server_conn: ServerConnection) -> None: try: transport_context = server_conn._transport_context_factory() - async with transport_context as (read_stream, write_stream): + async with transport_context as (read_stream, write_stream, _): server_conn.create_session(read_stream, write_stream) async with server_conn.session: @@ -303,14 +326,17 @@ def transport_context_factory(): error_handler = get_stderr_handler(server_name) # Explicitly ensure we're using our custom logger for stderr logger.debug(f"{server_name}: Creating stdio client with custom error handler") - return stdio_client(server_params, errlog=error_handler) + return _add_none_to_context(stdio_client(server_params, errlog=error_handler)) elif config.transport == "sse": - return sse_client( - config.url, - config.headers, - sse_read_timeout=config.read_transport_sse_timeout_seconds, + return _add_none_to_context( + sse_client( + config.url, + config.headers, + sse_read_timeout=config.read_transport_sse_timeout_seconds, + ) ) - + elif config.transport == "http": + return streamablehttp_client(config.url, config.headers) else: raise ValueError(f"Unsupported transport: {config.transport}") diff --git a/src/mcp_agent/mcp/prompts/prompt_server.py b/src/mcp_agent/mcp/prompts/prompt_server.py index 561cfe94d..2756049e4 100644 --- a/src/mcp_agent/mcp/prompts/prompt_server.py +++ b/src/mcp_agent/mcp/prompts/prompt_server.py @@ -335,7 +335,7 @@ def parse_args(): parser.add_argument( "--transport", type=str, - choices=["stdio", "sse"], + choices=["stdio", "sse", "http"], default="stdio", help="Transport to use (default: stdio)", ) @@ -502,14 +502,22 @@ async def async_main() -> int: return await test_prompt(args.test, config) # Start the server with the specified transport - if config.transport == "stdio": - await mcp.run_stdio_async() - else: # sse + if config.transport == "sse": # sse # Set the host and port in settings before running the server mcp.settings.host = config.host mcp.settings.port = config.port logger.info(f"Starting SSE server on {config.host}:{config.port}") await mcp.run_sse_async() + elif config.transport == "http": + mcp.settings.host = config.host + mcp.settings.port = config.port + logger.info(f"Starting SSE server on {config.host}:{config.port}") + await mcp.run_streamable_http_async() + elif config.transport == "stdio": + await mcp.run_stdio_async() + else: + logger.error(f"Unknown transport: {config.transport}") + return 1 return 0 diff --git a/src/mcp_agent/mcp_server/agent_server.py b/src/mcp_agent/mcp_server/agent_server.py index 36b3cc2ca..a0e1f9723 100644 --- a/src/mcp_agent/mcp_server/agent_server.py +++ b/src/mcp_agent/mcp_server/agent_server.py @@ -140,9 +140,9 @@ async def _handle_shutdown_signal(self, is_term=False): print("Press Ctrl+C again to force exit.") self._graceful_shutdown_event.set() - def run(self, transport: str = "sse", host: str = "0.0.0.0", port: int = 8000) -> None: + def run(self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000) -> None: """Run the MCP server synchronously.""" - if transport == "sse": + if transport in ["sse", "http"]: self.mcp_server.settings.host = host self.mcp_server.settings.port = port @@ -180,12 +180,12 @@ def run(self, transport: str = "sse", host: str = "0.0.0.0", port: int = 8000) - asyncio.run(self._cleanup_stdio()) async def run_async( - self, transport: str = "sse", host: str = "0.0.0.0", port: int = 8000 + self, transport: str = "http", host: str = "0.0.0.0", port: int = 8000 ) -> None: """Run the MCP server asynchronously with improved shutdown handling.""" # Use different handling strategies based on transport type - if transport == "sse": - # For SSE, use our enhanced shutdown handling + if transport in ["sse", "http"]: + # For SSE/HTTP, use our enhanced shutdown handling self._setup_signal_handlers() self.mcp_server.settings.host = host @@ -236,9 +236,9 @@ async def run_async( async def _run_server_with_shutdown(self, transport: str): """Run the server with proper shutdown handling.""" - # This method is only used for SSE transport - if transport != "sse": - raise ValueError("This method should only be used with SSE transport") + # This method is used for SSE/HTTP transport + if transport not in ["sse", "http"]: + raise ValueError("This method should only be used with SSE or HTTP transport") # Start a monitor task for shutdown shutdown_monitor = asyncio.create_task(self._monitor_shutdown()) @@ -262,8 +262,11 @@ async def tracked_connect_sse(*args, **kwargs): # Replace with our tracking version self.mcp_server._sse_transport.connect_sse = tracked_connect_sse - # Run the server (SSE only) - await self.mcp_server.run_sse_async() + # Run the server based on transport type + if transport == "sse": + await self.mcp_server.run_sse_async() + elif transport == "http": + await self.mcp_server.run_streamable_http_async() finally: # Cancel the monitor when the server exits shutdown_monitor.cancel() diff --git a/src/mcp_agent/mcp_server_registry.py b/src/mcp_agent/mcp_server_registry.py index da431f4f1..ddea78f58 100644 --- a/src/mcp_agent/mcp_server_registry.py +++ b/src/mcp_agent/mcp_server_registry.py @@ -18,6 +18,7 @@ StdioServerParameters, get_default_environment, ) +from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client from mcp_agent.config import ( MCPServerAuthSettings, @@ -27,7 +28,10 @@ ) from mcp_agent.logging.logger import get_logger from mcp_agent.mcp.logger_textio import get_stderr_handler -from mcp_agent.mcp.mcp_connection_manager import MCPConnectionManager +from mcp_agent.mcp.mcp_connection_manager import ( + MCPConnectionManager, + _add_none_to_context, +) logger = get_logger(__name__) @@ -93,7 +97,12 @@ async def start_server( self, server_name: str, client_session_factory: Callable[ - [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], + [ + MemoryObjectReceiveStream, + MemoryObjectSendStream, + timedelta | None, + GetSessionIdCallback | None, + ], ClientSession, ] = ClientSession, ) -> AsyncGenerator[ClientSession, None]: @@ -132,14 +141,18 @@ async def start_server( ) # Create a stderr handler that logs to our application logger - async with stdio_client(server_params, errlog=get_stderr_handler(server_name)) as ( + async with _add_none_to_context( + stdio_client(server_params, errlog=get_stderr_handler(server_name)) + ) as ( read_stream, write_stream, + _, ): session = client_session_factory( read_stream, write_stream, read_timeout_seconds, + None, # No callback for stdio ) async with session: logger.info(f"{server_name}: Connected to server using stdio transport.") @@ -153,15 +166,18 @@ async def start_server( raise ValueError(f"URL is required for SSE transport: {server_name}") # Use sse_client to get the read and write streams - async with sse_client( - config.url, - config.headers, - sse_read_timeout=config.read_transport_sse_timeout_seconds, - ) as (read_stream, write_stream): + async with _add_none_to_context( + sse_client( + config.url, + config.headers, + sse_read_timeout=config.read_transport_sse_timeout_seconds, + ) + ) as (read_stream, write_stream, _): session = client_session_factory( read_stream, write_stream, read_timeout_seconds, + None, # No callback for stdio ) async with session: logger.info(f"{server_name}: Connected to server using SSE transport.") @@ -169,6 +185,27 @@ async def start_server( yield session finally: logger.debug(f"{server_name}: Closed session to server") + elif config.transport == "http": + if not config.url: + raise ValueError(f"URL is required for SSE transport: {server_name}") + + async with streamablehttp_client(config.url, config.headers) as ( + read_stream, + write_stream, + _, + ): + session = client_session_factory( + read_stream, + write_stream, + read_timeout_seconds, + None, # No callback for stdio + ) + async with session: + logger.info(f"{server_name}: Connected to server using HTTP transport.") + try: + yield session + finally: + logger.debug(f"{server_name}: Closed session to server") # Unsupported transport else: @@ -179,7 +216,12 @@ async def initialize_server( self, server_name: str, client_session_factory: Callable[ - [MemoryObjectReceiveStream, MemoryObjectSendStream, timedelta | None], + [ + MemoryObjectReceiveStream, + MemoryObjectSendStream, + timedelta | None, + GetSessionIdCallback, + ], ClientSession, ] = ClientSession, init_hook: InitHookCallable = None, diff --git a/src/mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml b/src/mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml index 8ae2c432d..df8c9f584 100644 --- a/src/mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml +++ b/src/mcp_agent/resources/examples/mcp/state-transfer/fastagent.config.yaml @@ -23,5 +23,5 @@ logger: mcp: servers: agent_one: - transport: sse - url: http://localhost:8001/sse + transport: http + url: http://localhost:8001/mcp diff --git a/tests/integration/api/fastagent.config.yaml b/tests/integration/api/fastagent.config.yaml index 7953fe505..5eed30127 100644 --- a/tests/integration/api/fastagent.config.yaml +++ b/tests/integration/api/fastagent.config.yaml @@ -40,6 +40,9 @@ mcp: sse: transport: "sse" url: "http://localhost:8723/sse" + http: + transport: "http" + url: "http://localhost:8724/mcp" card_test: command: "uv" args: ["run", "mcp_tools_server.py"] diff --git a/tests/integration/api/test_cli_and_mcp_server.py b/tests/integration/api/test_cli_and_mcp_server.py index 2973da467..023668d69 100644 --- a/tests/integration/api/test_cli_and_mcp_server.py +++ b/tests/integration/api/test_cli_and_mcp_server.py @@ -185,7 +185,7 @@ async def test_agent_server_option_sse(fast_agent): try: # Give the server a moment to start - await asyncio.sleep(3) + await asyncio.sleep(2) # Now connect to it via the configured MCP server @fast_agent.agent(name="client", servers=["sse"]) @@ -206,3 +206,64 @@ async def agent_function(): server_proc.wait(timeout=2) except subprocess.TimeoutExpired: server_proc.kill() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_agent_server_option_http(fast_agent): + """Test that FastAgent supports --server flag with HTTP transport.""" + + # Start the SSE server in a subprocess + import asyncio + import os + import subprocess + + # Get the path to the test agent + test_dir = os.path.dirname(os.path.abspath(__file__)) + test_agent_path = os.path.join(test_dir, "integration_agent.py") + + # Port must match what's in the fastagent.config.yaml + port = 8724 + + # Start the server process + server_proc = subprocess.Popen( + [ + "uv", + "run", + test_agent_path, + "--server", + "--transport", + "http", + "--port", + str(port), + "--quiet", + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=test_dir, + ) + + try: + # Give the server a moment to start + await asyncio.sleep(2) + + # Now connect to it via the configured MCP server + @fast_agent.agent(name="client", servers=["http"]) + async def agent_function(): + async with fast_agent.run() as agent: + # Try connecting and sending a message + assert "connected" == await agent.send("connected") + result = await agent.send('***CALL_TOOL test_send {"message": "http server test"}') + assert "http server test" == result + + await agent_function() + + finally: + # Terminate the server process + if server_proc.poll() is None: # If still running + server_proc.terminate() + try: + server_proc.wait(timeout=2) + except subprocess.TimeoutExpired: + server_proc.kill() diff --git a/tests/integration/prompt-server/fastagent.config.yaml b/tests/integration/prompt-server/fastagent.config.yaml index 1854b754d..3dbdadda4 100644 --- a/tests/integration/prompt-server/fastagent.config.yaml +++ b/tests/integration/prompt-server/fastagent.config.yaml @@ -29,3 +29,6 @@ mcp: prompt_sse: transport: "sse" url: "http://localhost:8723/sse" + prompt_http: + transport: "http" + url: "http://localhost:8724/mcp" diff --git a/tests/integration/prompt-server/test_prompt_server_integration.py b/tests/integration/prompt-server/test_prompt_server_integration.py index ce1048885..629d93115 100644 --- a/tests/integration/prompt-server/test_prompt_server_integration.py +++ b/tests/integration/prompt-server/test_prompt_server_integration.py @@ -218,3 +218,49 @@ async def agent_function(): server_proc.wait(timeout=2) except subprocess.TimeoutExpired: server_proc.kill() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_prompt_server_http_can_set_ports(fast_agent): + # Start the SSE server in a subprocess + import asyncio + import os + import subprocess + + # Get the path to the test agent + test_dir = os.path.dirname(os.path.abspath(__file__)) + + # Port must match what's in the fastagent.config.yaml + port = 8724 + + # Start the server process + server_proc = subprocess.Popen( + ["prompt-server", "--transport", "http", "--port", str(port), "simple.txt"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + cwd=test_dir, + ) + + try: + # Give the server a moment to start + await asyncio.sleep(3) + + # Now connect to it via the configured MCP server + @fast_agent.agent(name="client", servers=["prompt_http"], model="passthrough") + async def agent_function(): + async with fast_agent.run() as agent: + # Try connecting and sending a message + assert "simple" in await agent.apply_prompt("simple") + + await agent_function() + + finally: + # Terminate the server process + if server_proc.poll() is None: # If still running + server_proc.terminate() + try: + server_proc.wait(timeout=2) + except subprocess.TimeoutExpired: + server_proc.kill() diff --git a/tests/unit/mcp_agent/cli/commands/test_url_parser.py b/tests/unit/mcp_agent/cli/commands/test_url_parser.py new file mode 100644 index 000000000..f3e6fd049 --- /dev/null +++ b/tests/unit/mcp_agent/cli/commands/test_url_parser.py @@ -0,0 +1,180 @@ +""" +Unit tests for the URL parser utility functions. +""" + +import pytest + +from mcp_agent.cli.commands.url_parser import ( + generate_server_configs, + generate_server_name, + parse_server_url, + parse_server_urls, +) + + +class TestUrlParser: + """Tests for URL parsing utilities.""" + + def test_parse_server_url_valid(self): + """Test parsing valid URLs.""" + # HTTP URL ending with /mcp + server_name, transport, url = parse_server_url("http://example.com/mcp") + assert server_name == "example_com" + assert transport == "http" + assert url == "http://example.com/mcp" + + # HTTPS URL ending with /sse + server_name, transport, url = parse_server_url("https://api.test.com/sse") + assert server_name == "api_test_com" + assert transport == "sse" + assert url == "https://api.test.com/sse" + + # URL without /mcp or /sse should append /mcp + server_name, transport, url = parse_server_url("http://localhost:8080/api") + assert transport == "http" + assert url == "http://localhost:8080/api/mcp" + + def test_parse_server_url_invalid(self): + """Test parsing invalid URLs.""" + # Empty URL + with pytest.raises(ValueError, match="URL cannot be empty"): + parse_server_url("") + + # Missing scheme + with pytest.raises(ValueError, match="URL must have http or https scheme"): + parse_server_url("example.com/mcp") + + # Invalid scheme + with pytest.raises(ValueError, match="URL must have http or https scheme"): + parse_server_url("ftp://example.com/mcp") + + # Missing hostname + with pytest.raises(ValueError, match="URL must include a hostname"): + parse_server_url("http:///mcp") + + def test_generate_server_name(self): + """Test server name generation from URLs.""" + # Standard domain + assert generate_server_name("http://example.com/mcp") == "example_com" + + # Domain with subdomain + assert generate_server_name("https://api.example.com/mcp") == "api_example_com" + + # Localhost with port + name = generate_server_name("http://localhost:8080/mcp") + assert name.startswith("localhost_8080_") + + # IP address + name = generate_server_name("http://192.168.1.1/api/mcp") + assert name.startswith("192_168_1_1_") + assert "api_mcp" in name or len(name.split("_")) > 4 + + # Long domain name + name = generate_server_name("http://very.long.domain.name:14432/api/someendpoint/mcp") + assert "very_long_name" == name + + def test_parse_server_urls(self): + """Test parsing multiple URLs.""" + urls = "http://example.com/mcp,https://api.test.com/sse,http://localhost:8080/api" + result = parse_server_urls(urls) + + assert len(result) == 3 + + # First URL + assert result[0][0] == "example_com" + assert result[0][1] == "http" + assert result[0][2] == "http://example.com/mcp" + assert result[0][3] is None # No auth headers + + # Second URL + assert result[1][0] == "api_test_com" + assert result[1][1] == "sse" + assert result[1][2] == "https://api.test.com/sse" + assert result[1][3] is None # No auth headers + + # Third URL + assert result[2][1] == "http" + assert result[2][2] == "http://localhost:8080/api/mcp" + assert result[2][3] is None # No auth headers + + # Empty input + assert parse_server_urls("") == [] + + def test_parse_server_urls_with_auth(self): + """Test parsing URLs with authentication token.""" + urls = "http://example.com/mcp,https://api.test.com/sse" + auth_token = "test_token_123" + result = parse_server_urls(urls, auth_token) + + assert len(result) == 2 + + # All URLs should have auth headers + for server_name, transport, url, headers in result: + assert headers is not None + assert headers == {"Authorization": "Bearer test_token_123"} + + def test_generate_server_configs(self): + """Test generating server configurations from parsed URLs.""" + parsed_urls = [ + ("example_com", "http", "http://example.com/mcp", None), + ("api_test_com", "sse", "https://api.test.com/sse", None), + ] + + configs = generate_server_configs(parsed_urls) + + assert len(configs) == 2 + + assert configs["example_com"]["transport"] == "http" + assert configs["example_com"]["url"] == "http://example.com/mcp" + assert "headers" not in configs["example_com"] + + assert configs["api_test_com"]["transport"] == "sse" + assert configs["api_test_com"]["url"] == "https://api.test.com/sse" + assert "headers" not in configs["api_test_com"] + + def test_generate_server_configs_with_auth(self): + """Test generating server configurations with auth headers.""" + auth_headers = {"Authorization": "Bearer test_token_123"} + parsed_urls = [ + ("example_com", "http", "http://example.com/mcp", auth_headers), + ("api_test_com", "sse", "https://api.test.com/sse", auth_headers), + ] + + configs = generate_server_configs(parsed_urls) + + assert len(configs) == 2 + + # Check both configs have headers + for server_name, config in configs.items(): + assert "headers" in config + assert config["headers"] == auth_headers + + def test_generate_server_configs_with_name_collisions(self): + """Test handling of server name collisions.""" + # Create a list of parsed URLs with the same server name + parsed_urls = [ + ( + "evalstate", + "sse", + "https://evalstate-parler-tts-expresso.hf.space/gradio_api/mcp/sse", + None, + ), + ("evalstate", "sse", "https://evalstate-shuttle.hf.space/gradio_api/mcp/sse", None), + ("evalstate", "http", "https://evalstate-another.hf.space/gradio_api/mcp", None), + ] + + configs = generate_server_configs(parsed_urls) + + # Should still have 3 configs despite name collisions + assert len(configs) == 3 + + # Should have unique keys by adding suffixes + expected_keys = {"evalstate", "evalstate_1", "evalstate_2"} + assert set(configs.keys()) == expected_keys + + # Check that URLs are preserved correctly + urls = {config["url"] for config in configs.values()} + assert len(urls) == 3 + assert "https://evalstate-parler-tts-expresso.hf.space/gradio_api/mcp/sse" in urls + assert "https://evalstate-shuttle.hf.space/gradio_api/mcp/sse" in urls + assert "https://evalstate-another.hf.space/gradio_api/mcp" in urls