diff --git a/src/mcp_agent/oauth/flow.py b/src/mcp_agent/oauth/flow.py index e2c523044..5488aecb8 100644 --- a/src/mcp_agent/oauth/flow.py +++ b/src/mcp_agent/oauth/flow.py @@ -136,7 +136,7 @@ async def authorize( import urllib.parse authorize_url = httpx.URL( - str(auth_metadata.authorization_endpoint) + str(auth_metadata.authorization_endpoint).rstrip("/") + "?" + urllib.parse.urlencode(params) ) diff --git a/src/mcp_agent/oauth/metadata.py b/src/mcp_agent/oauth/metadata.py index 7d2c974c4..0b6c62353 100644 --- a/src/mcp_agent/oauth/metadata.py +++ b/src/mcp_agent/oauth/metadata.py @@ -69,14 +69,19 @@ def select_authorization_server( "Protected resource metadata did not include authorization servers" ) - if preferred and preferred in candidates: - return preferred - if preferred: + preferred_normalized = preferred.rstrip("/") + candidates_normalized = [c.rstrip("/") for c in candidates] + + for i, candidate_normalized in enumerate(candidates_normalized): + if candidate_normalized == preferred_normalized: + return candidates[i] + logger.warning( "Preferred authorization server not listed; falling back to first entry", data={"preferred": preferred, "candidates": candidates}, ) + return candidates[0] diff --git a/src/mcp_agent/server/token_verifier.py b/src/mcp_agent/server/token_verifier.py index 4fff7fbc9..f5637e361 100644 --- a/src/mcp_agent/server/token_verifier.py +++ b/src/mcp_agent/server/token_verifier.py @@ -155,12 +155,14 @@ async def _introspect(self, token: str) -> MCPAccessToken | None: return None if self._settings.issuer_url and payload.get("iss"): - if str(payload.get("iss")) != str(self._settings.issuer_url): + expected_issuer = str(self._settings.issuer_url).rstrip("/") + actual_issuer = str(payload.get("iss")).rstrip("/") + if actual_issuer != expected_issuer: logger.warning( "Token issuer mismatch", data={ - "expected": str(self._settings.issuer_url), - "actual": payload.get("iss"), + "expected": expected_issuer, + "actual": actual_issuer, }, ) return None @@ -241,8 +243,10 @@ def _validate_audiences(self, token_audiences: List[str]) -> bool: return False # RFC 9068: Token MUST contain at least one expected audience - valid_audiences = set(self._settings.expected_audiences) - token_audience_set = set(token_audiences) + valid_audiences = set( + aud.rstrip("/") for aud in self._settings.expected_audiences + ) + token_audience_set = set(aud.rstrip("/") for aud in token_audiences) if not valid_audiences.intersection(token_audience_set): logger.warning( diff --git a/tests/test_oauth_utils.py b/tests/test_oauth_utils.py index 581608524..6207f9660 100644 --- a/tests/test_oauth_utils.py +++ b/tests/test_oauth_utils.py @@ -2,6 +2,7 @@ import asyncio import pathlib import sys +from typing import Any, Dict import pytest @@ -75,6 +76,65 @@ def test_select_authorization_server_prefers_explicit(): ) +def test_select_authorization_server_with_serialized_config(): + """Test that authorization server selection works after config json serialization. + + When MCPOAuthClientSettings is dumped with mode='json', the authorization_server + AnyHttpUrl field gets a trailing slash. This test ensures select_authorization_server + handles this correctly. + """ + from mcp_agent.config import MCPOAuthClientSettings + + oauth_config = MCPOAuthClientSettings( + enabled=True, + authorization_server="https://auth.example.com", + resource="https://api.example.com", + client_id="test_client", + ) + + dumped_config = oauth_config.model_dump(mode="json") + reloaded_config = MCPOAuthClientSettings(**dumped_config) + + metadata = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=[ + "https://auth.example.com", + "https://other-auth.example.com", + ], + ) + + dumped_metadata = metadata.model_dump(mode="json") + reloaded_metadata = ProtectedResourceMetadata(**dumped_metadata) + + preferred = str(reloaded_config.authorization_server) + selected = select_authorization_server(reloaded_metadata, preferred) + + assert selected.rstrip("/") == "https://auth.example.com" + + +def test_select_authorization_server_trailing_slash_mismatch(): + """Test trailing slash handling in select_authorization_server with various combinations.""" + # Test case 1: preferred has trailing slash, candidates don't + metadata1 = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=["https://auth.example.com", "https://other.example.com"], + ) + + selected1 = select_authorization_server(metadata1, "https://auth.example.com/") + assert selected1.rstrip("/") == "https://auth.example.com" + + # Test case 2: preferred doesn't have trailing slash, candidates do + metadata2 = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=[ + "https://auth.example.com/", + "https://other.example.com/", + ], + ) + selected2 = select_authorization_server(metadata2, "https://auth.example.com") + assert selected2.rstrip("/") == "https://auth.example.com" + + def test_normalize_resource_with_fallback(): assert ( normalize_resource("https://example.com/api", None) == "https://example.com/api" @@ -99,6 +159,27 @@ def test_oauth_loopback_ports_config_defaults(): assert 33418 in s.loopback_ports +def test_oauth_callback_base_url_with_serialized_config(): + """Test that callback_base_url works correctly after json serialization. + + When OAuthSettings is dumped with mode='json', the callback_base_url AnyHttpUrl + field gets a trailing slash. + """ + from mcp_agent.config import OAuthSettings + + settings = OAuthSettings(callback_base_url="https://callback.example.com") + dumped = settings.model_dump(mode="json") + reloaded = OAuthSettings(**dumped) + + flow_id = "test_flow_123" + if reloaded.callback_base_url: + constructed_url = f"{str(reloaded.callback_base_url).rstrip('/')}/internal/oauth/callback/{flow_id}" + + assert "//" not in constructed_url.replace("https://", "") + assert constructed_url.endswith(flow_id) + assert constructed_url.startswith("https://callback.example.com/") + + @pytest.mark.asyncio async def test_callback_registry_state_mapping(): from mcp_agent.oauth.callbacks import OAuthCallbackRegistry @@ -110,3 +191,79 @@ async def test_callback_registry_state_mapping(): assert delivered is True result = await asyncio.wait_for(fut, timeout=0.2) assert result["code"] == "abc" + + +@pytest.mark.asyncio +async def test_authorization_url_construction_with_trailing_slash(): + """Test that authorization URL is constructed correctly when endpoint has trailing slash.""" + from mcp_agent.oauth.flow import AuthorizationFlowCoordinator + from mcp_agent.config import OAuthSettings, MCPOAuthClientSettings + from mcp_agent.core.context import Context + from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata + from unittest.mock import MagicMock, patch + import httpx + + oauth_settings = OAuthSettings() + context = MagicMock(spec=Context) + from mcp_agent.oauth.identity import OAuthUserIdentity + + user = OAuthUserIdentity(subject="user123", provider="test") + + oauth_config = MCPOAuthClientSettings( + enabled=True, + client_id="test_client", + authorization_server="https://auth.example.com", + resource="https://api.example.com", + ) + + resource_metadata = ProtectedResourceMetadata( + resource="https://api.example.com/", + authorization_servers=["https://auth.example.com/"], + ) + + auth_metadata = OAuthMetadata( + issuer="https://auth.example.com/", + authorization_endpoint="https://auth.example.com/authorize/", + token_endpoint="https://auth.example.com/token/", + ) + + http_client = httpx.AsyncClient() + flow = AuthorizationFlowCoordinator( + http_client=http_client, settings=oauth_settings + ) + + captured_payload: Dict[str, Any] | None = None + + async def mock_send_auth_request(_ctx, payload: Dict[str, Any]): + nonlocal captured_payload + captured_payload = payload + # Simulate user declining to test the flow without needing real callback + raise ConnectionAbortedError("test_exception") + + with patch( + "mcp_agent.oauth.flow._send_auth_request", side_effect=mock_send_auth_request + ): + try: + await flow.authorize( + context=context, + user=user, + server_name="test_server", + oauth_config=oauth_config, + resource="https://api.example.com", + authorization_server_url="https://auth.example.com", + resource_metadata=resource_metadata, + auth_metadata=auth_metadata, + scopes=["read"], + ) + except ConnectionAbortedError: + pass # Expected to fail due to mock + + await http_client.aclose() + assert captured_payload is not None, "captured_payload should have been set by mock" + + # Type narrowing for Pylint + if captured_payload is not None: + url = captured_payload["url"] + assert "authorize/?" not in url + assert "authorize?" in url + assert url.startswith("https://auth.example.com/authorize?") diff --git a/tests/test_token_verifier.py b/tests/test_token_verifier.py index b611687f9..3fae706cb 100644 --- a/tests/test_token_verifier.py +++ b/tests/test_token_verifier.py @@ -855,3 +855,101 @@ async def test_audience_validation_failure_through_introspect(): assert token is None await verifier.aclose() + + +@pytest.mark.asyncio +async def test_issuer_comparison_with_trailing_slash_from_token(): + """Test that issuer comparison works when token has trailing slash. + + When config is loaded/dumped with mode='json', AnyHttpUrl fields may gain + trailing slashes. This test ensures the issuer comparison in token_verifier.py:158 + handles this correctly. + """ + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + expected_audiences=["https://api.example.com"], + ) + + # Dump with mode="json" and reload to simulate config loading (with trailing slashes) + dumped = settings.model_dump(mode="json") + reloaded_settings = MCPAuthorizationServerSettings(**dumped) + + verifier = MCPAgentTokenVerifier(reloaded_settings) + + metadata_response = Mock() + metadata_response.status_code = 200 + metadata_response.json.return_value = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "introspection_endpoint": "https://auth.example.com/introspect", + "response_types_supported": ["code"], + } + + introspect_response = Mock() + introspect_response.status_code = 200 + introspect_response.json.return_value = { + "active": True, + "aud": "https://api.example.com/", + "sub": "user123", + "exp": 9999999999, + "iss": "https://auth.example.com/", # trailing slash + } + + verifier._client.get = AsyncMock(return_value=metadata_response) + verifier._client.post = AsyncMock(return_value=introspect_response) + + token = await verifier._introspect("test_token") + + assert token is not None + assert token.subject == "user123" + + await verifier.aclose() + + +@pytest.mark.asyncio +async def test_issuer_comparison_config_trailing_slash_token_without(): + """Test issuer comparison when config has trailing slash but token doesn't.""" + settings = MCPAuthorizationServerSettings( + enabled=True, + issuer_url="https://auth.example.com", + resource_server_url="https://api.example.com", + expected_audiences=["https://api.example.com"], + ) + + dumped = settings.model_dump(mode="json") + reloaded_settings = MCPAuthorizationServerSettings(**dumped) + + verifier = MCPAgentTokenVerifier(reloaded_settings) + + metadata_response = Mock() + metadata_response.status_code = 200 + metadata_response.json.return_value = { + "issuer": "https://auth.example.com", + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token", + "introspection_endpoint": "https://auth.example.com/introspect", + "response_types_supported": ["code"], + } + + introspect_response = Mock() + introspect_response.status_code = 200 + introspect_response.json.return_value = { + "active": True, + "aud": "https://api.example.com", + "sub": "user123", + "exp": 9999999999, + "iss": "https://auth.example.com", # No trailing slash + } + + verifier._client.get = AsyncMock(return_value=metadata_response) + verifier._client.post = AsyncMock(return_value=introspect_response) + + token = await verifier._introspect("test_token") + + assert token is not None + assert token.subject == "user123" + + await verifier.aclose()