Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/fast_agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,35 @@ class MCPServerAuthSettings(BaseModel):
# Token persistence: use OS keychain via 'keyring' by default; fallback to 'memory'.
persist: Literal["keyring", "memory"] = "keyring"

# Client ID Metadata Document (CIMD) URL.
# When provided and the server advertises client_id_metadata_document_supported=true,
# this URL will be used as the client_id instead of performing dynamic client registration.
# Must be a valid HTTPS URL with a non-root pathname (e.g., https://example.com/client.json).
# See: https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization
client_metadata_url: str | None = None

model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)

@field_validator("client_metadata_url", mode="after")
@classmethod
def _validate_client_metadata_url(cls, v: str | None) -> str | None:
"""Validate that client_metadata_url is a valid HTTPS URL with a non-root path."""
if v is None:
return None
from urllib.parse import urlparse

try:
parsed = urlparse(v)
if parsed.scheme != "https":
raise ValueError("client_metadata_url must use HTTPS scheme")
if parsed.path in ("", "/"):
raise ValueError("client_metadata_url must have a non-root pathname")
return v
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid client_metadata_url: {e}")


class MCPSamplingSettings(BaseModel):
model: str = "gpt-5-mini.low"
Expand Down
88 changes: 79 additions & 9 deletions src/fast_agent/mcp/oauth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,28 @@ def log_message(self, format: str, *args: Any) -> None: # silence default loggi


class _CallbackServer:
"""Simple background HTTP server to receive a single OAuth callback."""
"""Simple background HTTP server to receive a single OAuth callback.

Uses 127.0.0.1 (loopback IP) instead of localhost for RFC 8252 compliance.
Per RFC 8252 Section 7.3, authorization servers MUST allow any port for
loopback IP redirect URIs, enabling dynamic port allocation.
"""

# Fallback ports to try if preferred port is unavailable
FALLBACK_PORTS = [3030, 3031, 3032, 8080, 0] # 0 = ephemeral port

def __init__(self, port: int, path: str) -> None:
self._port = port
self._preferred_port = port
self._path = path.rstrip("/") or "/callback"
self._result = _CallbackResult()
self._server: HTTPServer | None = None
self._thread: threading.Thread | None = None
self._actual_port: int | None = None

@property
def actual_port(self) -> int | None:
"""Return the actual port the server bound to (may differ from preferred)."""
return self._actual_port

def _make_handler(self) -> Callable[..., BaseHTTPRequestHandler]:
result = self._result
Expand All @@ -132,11 +146,46 @@ def handler(*args, **kwargs):

return handler

def _try_bind(self, port: int) -> HTTPServer | None:
"""Try to bind to the given port. Returns server if successful, None otherwise."""
try:
# Use 127.0.0.1 (loopback IP) for RFC 8252 compliance
server = HTTPServer(("127.0.0.1", port), self._make_handler())
return server
except OSError as e:
# EADDRINUSE (98 on Linux, 48 on macOS) or similar
logger.debug(f"Port {port} unavailable: {e}")
return None

def start(self) -> None:
self._server = HTTPServer(("localhost", self._port), self._make_handler())
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start()
logger.info(f"OAuth callback server listening on http://localhost:{self._port}{self._path}")
"""Start the callback server, trying fallback ports if preferred is unavailable."""
# Build list of ports to try: preferred first, then fallbacks
ports_to_try = [self._preferred_port]
for p in self.FALLBACK_PORTS:
if p not in ports_to_try:
ports_to_try.append(p)

for port in ports_to_try:
server = self._try_bind(port)
if server is not None:
self._server = server
# Get actual port (important when using ephemeral port 0)
self._actual_port = self._server.server_address[1]
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
self._thread.start()
logger.info(
f"OAuth callback server listening on http://127.0.0.1:{self._actual_port}{self._path}"
)
if self._actual_port != self._preferred_port:
logger.info(
f"Note: Using port {self._actual_port} instead of preferred port {self._preferred_port}"
)
return

raise OSError(
f"Could not bind to any port. Tried: {ports_to_try}. "
"All ports may be in use."
)

def stop(self) -> None:
if self._server:
Expand All @@ -158,6 +207,12 @@ def wait(self, timeout_seconds: int = 300) -> tuple[str, str | None]:
time.sleep(0.1)
raise TimeoutError("Timeout waiting for OAuth callback")

def get_redirect_uri(self) -> str:
"""Return the actual redirect URI based on bound port."""
if self._actual_port is None:
raise RuntimeError("Server not started; cannot determine redirect URI")
return f"http://127.0.0.1:{self._actual_port}{self._path}"


def _derive_base_server_url(url: str | None) -> str | None:
"""Derive the base server URL for OAuth discovery from an MCP endpoint URL.
Expand Down Expand Up @@ -417,6 +472,7 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide
redirect_path = "/callback"
scope_value: str | None = None
persist_mode: str = "keyring"
client_metadata_url: str | None = None

if server_config.auth is not None:
try:
Expand All @@ -425,6 +481,7 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide
redirect_path = getattr(server_config.auth, "redirect_path", "/callback")
scope_field = getattr(server_config.auth, "scope", None)
persist_mode = getattr(server_config.auth, "persist", "keyring")
client_metadata_url = getattr(server_config.auth, "client_metadata_url", None)
if isinstance(scope_field, list):
scope_value = " ".join(scope_field)
elif isinstance(scope_field, str):
Expand All @@ -440,11 +497,23 @@ def build_oauth_provider(server_config: MCPServerSettings) -> OAuthClientProvide
# No usable URL -> cannot build provider
return None

# Construct client metadata with minimal defaults
redirect_uri = f"http://localhost:{redirect_port}{redirect_path}"
# Construct client metadata with minimal defaults.
# Use 127.0.0.1 (loopback IP) for RFC 8252 compliance. Per RFC 8252 Section 7.3,
# authorization servers MUST allow any port for loopback IP redirect URIs.
# We register multiple redirect URIs to support port fallback for servers that
# don't fully implement RFC 8252's dynamic port matching.
redirect_uris: list[AnyUrl] = []
# Build list of ports: preferred first, then fallbacks
ports_for_registration = [redirect_port]
for p in _CallbackServer.FALLBACK_PORTS:
if p != 0 and p not in ports_for_registration: # Skip ephemeral port (0)
ports_for_registration.append(p)
for port in ports_for_registration:
redirect_uris.append(AnyUrl(f"http://127.0.0.1:{port}{redirect_path}"))

metadata_kwargs: dict[str, Any] = {
"client_name": "fast-agent",
"redirect_uris": [AnyUrl(redirect_uri)],
"redirect_uris": redirect_uris,
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
}
Expand Down Expand Up @@ -504,6 +573,7 @@ async def _callback_handler() -> tuple[str, str | None]:
storage=storage,
redirect_handler=_redirect_handler,
callback_handler=_callback_handler,
client_metadata_url=client_metadata_url,
)

return provider
Loading
Loading