diff --git a/src/fast_agent/config.py b/src/fast_agent/config.py index 58fad7b2..cd2c3f07 100644 --- a/src/fast_agent/config.py +++ b/src/fast_agent/config.py @@ -39,8 +39,35 @@ class MCPServerAuthSettings(BaseModel): # Token persistence: use OS keychain via 'keyring' by default; fallback to 'memory'. persist: Literal["keyring", "memory"] = "keyring" + # Client ID Metadata Document (CIMD) URL. + # When provided and the server advertises client_id_metadata_document_supported=true, + # this URL will be used as the client_id instead of performing dynamic client registration. + # Must be a valid HTTPS URL with a non-root pathname (e.g., https://example.com/client.json). + # See: https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization + client_metadata_url: str | None = None + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + @field_validator("client_metadata_url", mode="after") + @classmethod + def _validate_client_metadata_url(cls, v: str | None) -> str | None: + """Validate that client_metadata_url is a valid HTTPS URL with a non-root path.""" + if v is None: + return None + from urllib.parse import urlparse + + try: + parsed = urlparse(v) + if parsed.scheme != "https": + raise ValueError("client_metadata_url must use HTTPS scheme") + if parsed.path in ("", "/"): + raise ValueError("client_metadata_url must have a non-root pathname") + return v + except ValueError: + raise + except Exception as e: + raise ValueError(f"Invalid client_metadata_url: {e}") + class MCPSamplingSettings(BaseModel): model: str = "gpt-5-mini.low" diff --git a/src/fast_agent/mcp/oauth_client.py b/src/fast_agent/mcp/oauth_client.py index 7b9bc8a7..3a60d784 100644 --- a/src/fast_agent/mcp/oauth_client.py +++ b/src/fast_agent/mcp/oauth_client.py @@ -114,14 +114,28 @@ def log_message(self, format: str, *args: Any) -> None: # silence default loggi class _CallbackServer: - """Simple background HTTP server to receive a single OAuth callback.""" + """Simple background HTTP server to receive a single OAuth callback. + + Uses 127.0.0.1 (loopback IP) instead of localhost for RFC 8252 compliance. + Per RFC 8252 Section 7.3, authorization servers MUST allow any port for + loopback IP redirect URIs, enabling dynamic port allocation. + """ + + # Fallback ports to try if preferred port is unavailable + FALLBACK_PORTS = [3030, 3031, 3032, 8080, 0] # 0 = ephemeral port def __init__(self, port: int, path: str) -> None: - self._port = port + self._preferred_port = port self._path = path.rstrip("/") or "/callback" self._result = _CallbackResult() self._server: HTTPServer | None = None self._thread: threading.Thread | None = None + self._actual_port: int | None = None + + @property + def actual_port(self) -> int | None: + """Return the actual port the server bound to (may differ from preferred).""" + return self._actual_port def _make_handler(self) -> Callable[..., BaseHTTPRequestHandler]: result = self._result @@ -132,11 +146,46 @@ def handler(*args, **kwargs): return handler + def _try_bind(self, port: int) -> HTTPServer | None: + """Try to bind to the given port. Returns server if successful, None otherwise.""" + try: + # Use 127.0.0.1 (loopback IP) for RFC 8252 compliance + server = HTTPServer(("127.0.0.1", port), self._make_handler()) + return server + except OSError as e: + # EADDRINUSE (98 on Linux, 48 on macOS) or similar + logger.debug(f"Port {port} unavailable: {e}") + return None + def start(self) -> None: - self._server = HTTPServer(("localhost", self._port), self._make_handler()) - self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) - self._thread.start() - logger.info(f"OAuth callback server listening on http://localhost:{self._port}{self._path}") + """Start the callback server, trying fallback ports if preferred is unavailable.""" + # Build list of ports to try: preferred first, then fallbacks + ports_to_try = [self._preferred_port] + for p in self.FALLBACK_PORTS: + if p not in ports_to_try: + ports_to_try.append(p) + + for port in ports_to_try: + server = self._try_bind(port) + if server is not None: + self._server = server + # Get actual port (important when using ephemeral port 0) + self._actual_port = self._server.server_address[1] + self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) + self._thread.start() + logger.info( + f"OAuth callback server listening on http://127.0.0.1:{self._actual_port}{self._path}" + ) + if self._actual_port != self._preferred_port: + logger.info( + f"Note: Using port {self._actual_port} instead of preferred port {self._preferred_port}" + ) + return + + raise OSError( + f"Could not bind to any port. Tried: {ports_to_try}. " + "All ports may be in use." + ) def stop(self) -> None: if self._server: @@ -158,6 +207,12 @@ def wait(self, timeout_seconds: int = 300) -> tuple[str, str | None]: time.sleep(0.1) raise TimeoutError("Timeout waiting for OAuth callback") + def get_redirect_uri(self) -> str: + """Return the actual redirect URI based on bound port.""" + if self._actual_port is None: + raise RuntimeError("Server not started; cannot determine redirect URI") + return f"http://127.0.0.1:{self._actual_port}{self._path}" + def _derive_base_server_url(url: str | None) -> str | None: """Derive the base server URL for OAuth discovery from an MCP endpoint URL. @@ -417,6 +472,7 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide redirect_path = "/callback" scope_value: str | None = None persist_mode: str = "keyring" + client_metadata_url: str | None = None if server_config.auth is not None: try: @@ -425,6 +481,7 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide redirect_path = getattr(server_config.auth, "redirect_path", "/callback") scope_field = getattr(server_config.auth, "scope", None) persist_mode = getattr(server_config.auth, "persist", "keyring") + client_metadata_url = getattr(server_config.auth, "client_metadata_url", None) if isinstance(scope_field, list): scope_value = " ".join(scope_field) elif isinstance(scope_field, str): @@ -440,11 +497,23 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide # No usable URL -> cannot build provider return None - # Construct client metadata with minimal defaults - redirect_uri = f"http://localhost:{redirect_port}{redirect_path}" + # Construct client metadata with minimal defaults. + # Use 127.0.0.1 (loopback IP) for RFC 8252 compliance. Per RFC 8252 Section 7.3, + # authorization servers MUST allow any port for loopback IP redirect URIs. + # We register multiple redirect URIs to support port fallback for servers that + # don't fully implement RFC 8252's dynamic port matching. + redirect_uris: list[AnyUrl] = [] + # Build list of ports: preferred first, then fallbacks + ports_for_registration = [redirect_port] + for p in _CallbackServer.FALLBACK_PORTS: + if p != 0 and p not in ports_for_registration: # Skip ephemeral port (0) + ports_for_registration.append(p) + for port in ports_for_registration: + redirect_uris.append(AnyUrl(f"http://127.0.0.1:{port}{redirect_path}")) + metadata_kwargs: dict[str, Any] = { "client_name": "fast-agent", - "redirect_uris": [AnyUrl(redirect_uri)], + "redirect_uris": redirect_uris, "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], } @@ -504,6 +573,7 @@ async def _callback_handler() -> tuple[str, str | None]: storage=storage, redirect_handler=_redirect_handler, callback_handler=_callback_handler, + client_metadata_url=client_metadata_url, ) return provider diff --git a/tests/unit/fast_agent/mcp/test_cimd.py b/tests/unit/fast_agent/mcp/test_cimd.py new file mode 100644 index 00000000..0855364d --- /dev/null +++ b/tests/unit/fast_agent/mcp/test_cimd.py @@ -0,0 +1,293 @@ +"""Tests for Client ID Metadata Document (CIMD) support.""" + +import socket + +import pytest +from pydantic import ValidationError + +from fast_agent.config import MCPServerAuthSettings, MCPServerSettings +from fast_agent.mcp.oauth_client import _CallbackServer, build_oauth_provider + + +class TestCIMDConfigValidation: + """Test CIMD URL validation in MCPServerAuthSettings.""" + + def test_valid_cimd_url(self): + """A valid HTTPS URL with non-root path should be accepted.""" + auth = MCPServerAuthSettings( + client_metadata_url="https://example.com/client.json" + ) + assert auth.client_metadata_url == "https://example.com/client.json" + + def test_valid_cimd_url_with_path(self): + """A valid HTTPS URL with a deep path should be accepted.""" + auth = MCPServerAuthSettings( + client_metadata_url="https://example.com/oauth/client-metadata.json" + ) + assert auth.client_metadata_url == "https://example.com/oauth/client-metadata.json" + + def test_cimd_url_rejects_http(self): + """HTTP URLs should be rejected (must be HTTPS).""" + with pytest.raises(ValidationError) as exc_info: + MCPServerAuthSettings( + client_metadata_url="http://example.com/client.json" + ) + assert "client_metadata_url must use HTTPS scheme" in str(exc_info.value) + + def test_cimd_url_rejects_root_path(self): + """URLs with root path (/) should be rejected.""" + with pytest.raises(ValidationError) as exc_info: + MCPServerAuthSettings( + client_metadata_url="https://example.com/" + ) + assert "client_metadata_url must have a non-root pathname" in str(exc_info.value) + + def test_cimd_url_rejects_no_path(self): + """URLs with no path should be rejected.""" + with pytest.raises(ValidationError) as exc_info: + MCPServerAuthSettings( + client_metadata_url="https://example.com" + ) + assert "client_metadata_url must have a non-root pathname" in str(exc_info.value) + + def test_cimd_url_none_by_default(self): + """client_metadata_url should be None by default.""" + auth = MCPServerAuthSettings() + assert auth.client_metadata_url is None + + +class TestCIMDOAuthProvider: + """Test that CIMD URL is passed to OAuthClientProvider.""" + + def test_build_oauth_provider_with_cimd_url(self, monkeypatch): + """build_oauth_provider should pass client_metadata_url to OAuthClientProvider.""" + captured_kwargs = {} + + class MockOAuthClientProvider: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr( + "fast_agent.mcp.oauth_client.OAuthClientProvider", + MockOAuthClientProvider, + ) + + auth = MCPServerAuthSettings( + client_metadata_url="https://example.com/client.json" + ) + config = MCPServerSettings( + name="test", + transport="http", + url="https://example.com/mcp", + auth=auth, + ) + + build_oauth_provider(config) + + assert captured_kwargs.get("client_metadata_url") == "https://example.com/client.json" + + def test_build_oauth_provider_without_cimd_url(self, monkeypatch): + """build_oauth_provider should pass None for client_metadata_url when not configured.""" + captured_kwargs = {} + + class MockOAuthClientProvider: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr( + "fast_agent.mcp.oauth_client.OAuthClientProvider", + MockOAuthClientProvider, + ) + + config = MCPServerSettings( + name="test", + transport="http", + url="https://example.com/mcp", + ) + + build_oauth_provider(config) + + assert captured_kwargs.get("client_metadata_url") is None + + def test_build_oauth_provider_cimd_with_sse_transport(self, monkeypatch): + """build_oauth_provider should work with SSE transport and CIMD.""" + captured_kwargs = {} + + class MockOAuthClientProvider: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr( + "fast_agent.mcp.oauth_client.OAuthClientProvider", + MockOAuthClientProvider, + ) + + auth = MCPServerAuthSettings( + client_metadata_url="https://example.com/client.json" + ) + config = MCPServerSettings( + name="test", + transport="sse", + url="https://example.com/sse", + auth=auth, + ) + + build_oauth_provider(config) + + assert captured_kwargs.get("client_metadata_url") == "https://example.com/client.json" + + def test_build_oauth_provider_stdio_ignores_cimd(self): + """build_oauth_provider should return None for stdio transport (no OAuth).""" + auth = MCPServerAuthSettings( + client_metadata_url="https://example.com/client.json" + ) + config = MCPServerSettings( + name="test", + transport="stdio", + command="echo", + auth=auth, + ) + + result = build_oauth_provider(config) + + assert result is None + + def test_build_oauth_provider_oauth_disabled_ignores_cimd(self): + """build_oauth_provider should return None when OAuth is disabled.""" + auth = MCPServerAuthSettings( + oauth=False, + client_metadata_url="https://example.com/client.json" + ) + config = MCPServerSettings( + name="test", + transport="http", + url="https://example.com/mcp", + auth=auth, + ) + + result = build_oauth_provider(config) + + assert result is None + + def test_build_oauth_provider_uses_loopback_ip(self, monkeypatch): + """build_oauth_provider should use 127.0.0.1 (loopback IP) for RFC 8252 compliance.""" + captured_kwargs = {} + + class MockOAuthClientProvider: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr( + "fast_agent.mcp.oauth_client.OAuthClientProvider", + MockOAuthClientProvider, + ) + + config = MCPServerSettings( + name="test", + transport="http", + url="https://example.com/mcp", + ) + + build_oauth_provider(config) + + # Check that redirect_uris use 127.0.0.1 instead of localhost + client_metadata = captured_kwargs.get("client_metadata") + assert client_metadata is not None + redirect_uris = [str(uri) for uri in client_metadata.redirect_uris] + assert all("127.0.0.1" in uri for uri in redirect_uris) + assert not any("localhost" in uri for uri in redirect_uris) + + def test_build_oauth_provider_registers_fallback_ports(self, monkeypatch): + """build_oauth_provider should register multiple ports for fallback support.""" + captured_kwargs = {} + + class MockOAuthClientProvider: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + monkeypatch.setattr( + "fast_agent.mcp.oauth_client.OAuthClientProvider", + MockOAuthClientProvider, + ) + + config = MCPServerSettings( + name="test", + transport="http", + url="https://example.com/mcp", + ) + + build_oauth_provider(config) + + client_metadata = captured_kwargs.get("client_metadata") + assert client_metadata is not None + redirect_uris = [str(uri) for uri in client_metadata.redirect_uris] + # Should have multiple redirect URIs for port fallback + assert len(redirect_uris) >= 3 + # Should include default port 3030 + assert any(":3030/" in uri for uri in redirect_uris) + + +class TestCallbackServerPortFallback: + """Test RFC 8252 compliant port fallback in _CallbackServer.""" + + def test_callback_server_uses_loopback_ip(self): + """_CallbackServer should bind to 127.0.0.1 for RFC 8252 compliance.""" + server = _CallbackServer(port=0, path="/callback") # Use ephemeral port + try: + server.start() + assert server.actual_port is not None + assert server.actual_port > 0 + # The server is bound to 127.0.0.1 + assert server._server.server_address[0] == "127.0.0.1" + finally: + server.stop() + + def test_callback_server_ephemeral_port(self): + """_CallbackServer should work with ephemeral port (0).""" + server = _CallbackServer(port=0, path="/callback") + try: + server.start() + # Should get a real port assigned + assert server.actual_port is not None + assert server.actual_port > 0 + finally: + server.stop() + + def test_callback_server_get_redirect_uri(self): + """get_redirect_uri should return the actual bound port.""" + server = _CallbackServer(port=0, path="/callback") + try: + server.start() + redirect_uri = server.get_redirect_uri() + assert redirect_uri.startswith("http://127.0.0.1:") + assert redirect_uri.endswith("/callback") + assert f":{server.actual_port}/" in redirect_uri + finally: + server.stop() + + def test_callback_server_port_fallback(self): + """_CallbackServer should fall back to next port if preferred is in use.""" + # Occupy the preferred port + blocker = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + blocker.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + blocker.bind(("127.0.0.1", 13030)) # Use unusual port to avoid conflicts + blocker.listen(1) + + # Try to start server on blocked port + server = _CallbackServer(port=13030, path="/callback") + try: + server.start() + # Should have fallen back to a different port + assert server.actual_port != 13030 + assert server.actual_port is not None + finally: + server.stop() + finally: + blocker.close() + + def test_callback_server_get_redirect_uri_before_start_raises(self): + """get_redirect_uri should raise if called before start().""" + server = _CallbackServer(port=3030, path="/callback") + with pytest.raises(RuntimeError, match="Server not started"): + server.get_redirect_uri()