Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp_agent/oauth/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ async def authorize(
import urllib.parse

authorize_url = httpx.URL(
str(auth_metadata.authorization_endpoint)
str(auth_metadata.authorization_endpoint).rstrip("/")
+ "?"
+ urllib.parse.urlencode(params)
)
Expand Down
11 changes: 8 additions & 3 deletions src/mcp_agent/oauth/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,19 @@ def select_authorization_server(
"Protected resource metadata did not include authorization servers"
)

if preferred and preferred in candidates:
return preferred

if preferred:
preferred_normalized = preferred.rstrip("/")
candidates_normalized = [c.rstrip("/") for c in candidates]

for i, candidate_normalized in enumerate(candidates_normalized):
if candidate_normalized == preferred_normalized:
return candidates[i]

logger.warning(
"Preferred authorization server not listed; falling back to first entry",
data={"preferred": preferred, "candidates": candidates},
)

return candidates[0]


Expand Down
14 changes: 9 additions & 5 deletions src/mcp_agent/server/token_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,14 @@ async def _introspect(self, token: str) -> MCPAccessToken | None:
return None

if self._settings.issuer_url and payload.get("iss"):
if str(payload.get("iss")) != str(self._settings.issuer_url):
expected_issuer = str(self._settings.issuer_url).rstrip("/")
actual_issuer = str(payload.get("iss")).rstrip("/")
if actual_issuer != expected_issuer:
logger.warning(
"Token issuer mismatch",
data={
"expected": str(self._settings.issuer_url),
"actual": payload.get("iss"),
"expected": expected_issuer,
"actual": actual_issuer,
},
)
return None
Expand Down Expand Up @@ -241,8 +243,10 @@ def _validate_audiences(self, token_audiences: List[str]) -> bool:
return False

# RFC 9068: Token MUST contain at least one expected audience
valid_audiences = set(self._settings.expected_audiences)
token_audience_set = set(token_audiences)
valid_audiences = set(
aud.rstrip("/") for aud in self._settings.expected_audiences
)
token_audience_set = set(aud.rstrip("/") for aud in token_audiences)

if not valid_audiences.intersection(token_audience_set):
logger.warning(
Expand Down
157 changes: 157 additions & 0 deletions tests/test_oauth_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import pathlib
import sys
from typing import Any, Dict

import pytest

Expand Down Expand Up @@ -75,6 +76,65 @@ def test_select_authorization_server_prefers_explicit():
)


def test_select_authorization_server_with_serialized_config():
"""Test that authorization server selection works after config json serialization.

When MCPOAuthClientSettings is dumped with mode='json', the authorization_server
AnyHttpUrl field gets a trailing slash. This test ensures select_authorization_server
handles this correctly.
"""
from mcp_agent.config import MCPOAuthClientSettings

oauth_config = MCPOAuthClientSettings(
enabled=True,
authorization_server="https://auth.example.com",
resource="https://api.example.com",
client_id="test_client",
)

dumped_config = oauth_config.model_dump(mode="json")
reloaded_config = MCPOAuthClientSettings(**dumped_config)

metadata = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=[
"https://auth.example.com",
"https://other-auth.example.com",
],
)

dumped_metadata = metadata.model_dump(mode="json")
reloaded_metadata = ProtectedResourceMetadata(**dumped_metadata)

preferred = str(reloaded_config.authorization_server)
selected = select_authorization_server(reloaded_metadata, preferred)

assert selected.rstrip("/") == "https://auth.example.com"


def test_select_authorization_server_trailing_slash_mismatch():
"""Test trailing slash handling in select_authorization_server with various combinations."""
# Test case 1: preferred has trailing slash, candidates don't
metadata1 = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=["https://auth.example.com", "https://other.example.com"],
)

selected1 = select_authorization_server(metadata1, "https://auth.example.com/")
assert selected1.rstrip("/") == "https://auth.example.com"

# Test case 2: preferred doesn't have trailing slash, candidates do
metadata2 = ProtectedResourceMetadata(
resource="https://api.example.com",
authorization_servers=[
"https://auth.example.com/",
"https://other.example.com/",
],
)
selected2 = select_authorization_server(metadata2, "https://auth.example.com")
assert selected2.rstrip("/") == "https://auth.example.com"


def test_normalize_resource_with_fallback():
assert (
normalize_resource("https://example.com/api", None) == "https://example.com/api"
Expand All @@ -99,6 +159,27 @@ def test_oauth_loopback_ports_config_defaults():
assert 33418 in s.loopback_ports


def test_oauth_callback_base_url_with_serialized_config():
"""Test that callback_base_url works correctly after json serialization.

When OAuthSettings is dumped with mode='json', the callback_base_url AnyHttpUrl
field gets a trailing slash.
"""
from mcp_agent.config import OAuthSettings

settings = OAuthSettings(callback_base_url="https://callback.example.com")
dumped = settings.model_dump(mode="json")
reloaded = OAuthSettings(**dumped)

flow_id = "test_flow_123"
if reloaded.callback_base_url:
constructed_url = f"{str(reloaded.callback_base_url).rstrip('/')}/internal/oauth/callback/{flow_id}"

assert "//" not in constructed_url.replace("https://", "")
assert constructed_url.endswith(flow_id)
assert constructed_url.startswith("https://callback.example.com/")


@pytest.mark.asyncio
async def test_callback_registry_state_mapping():
from mcp_agent.oauth.callbacks import OAuthCallbackRegistry
Expand All @@ -110,3 +191,79 @@ async def test_callback_registry_state_mapping():
assert delivered is True
result = await asyncio.wait_for(fut, timeout=0.2)
assert result["code"] == "abc"


@pytest.mark.asyncio
async def test_authorization_url_construction_with_trailing_slash():
"""Test that authorization URL is constructed correctly when endpoint has trailing slash."""
from mcp_agent.oauth.flow import AuthorizationFlowCoordinator
from mcp_agent.config import OAuthSettings, MCPOAuthClientSettings
from mcp_agent.core.context import Context
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
from unittest.mock import MagicMock, patch
import httpx

oauth_settings = OAuthSettings()
context = MagicMock(spec=Context)
from mcp_agent.oauth.identity import OAuthUserIdentity

user = OAuthUserIdentity(subject="user123", provider="test")

oauth_config = MCPOAuthClientSettings(
enabled=True,
client_id="test_client",
authorization_server="https://auth.example.com",
resource="https://api.example.com",
)

resource_metadata = ProtectedResourceMetadata(
resource="https://api.example.com/",
authorization_servers=["https://auth.example.com/"],
)

auth_metadata = OAuthMetadata(
issuer="https://auth.example.com/",
authorization_endpoint="https://auth.example.com/authorize/",
token_endpoint="https://auth.example.com/token/",
)

http_client = httpx.AsyncClient()
flow = AuthorizationFlowCoordinator(
http_client=http_client, settings=oauth_settings
)

captured_payload: Dict[str, Any] | None = None

async def mock_send_auth_request(_ctx, payload: Dict[str, Any]):
nonlocal captured_payload
captured_payload = payload
# Simulate user declining to test the flow without needing real callback
raise ConnectionAbortedError("test_exception")

with patch(
"mcp_agent.oauth.flow._send_auth_request", side_effect=mock_send_auth_request
):
try:
await flow.authorize(
context=context,
user=user,
server_name="test_server",
oauth_config=oauth_config,
resource="https://api.example.com",
authorization_server_url="https://auth.example.com",
resource_metadata=resource_metadata,
auth_metadata=auth_metadata,
scopes=["read"],
)
except ConnectionAbortedError:
pass # Expected to fail due to mock

await http_client.aclose()
assert captured_payload is not None, "captured_payload should have been set by mock"

# Type narrowing for Pylint
if captured_payload is not None:
url = captured_payload["url"]
assert "authorize/?" not in url
assert "authorize?" in url
assert url.startswith("https://auth.example.com/authorize?")
98 changes: 98 additions & 0 deletions tests/test_token_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,3 +855,101 @@ async def test_audience_validation_failure_through_introspect():
assert token is None

await verifier.aclose()


@pytest.mark.asyncio
async def test_issuer_comparison_with_trailing_slash_from_token():
"""Test that issuer comparison works when token has trailing slash.

When config is loaded/dumped with mode='json', AnyHttpUrl fields may gain
trailing slashes. This test ensures the issuer comparison in token_verifier.py:158
handles this correctly.
"""
settings = MCPAuthorizationServerSettings(
enabled=True,
issuer_url="https://auth.example.com",
resource_server_url="https://api.example.com",
expected_audiences=["https://api.example.com"],
)

# Dump with mode="json" and reload to simulate config loading (with trailing slashes)
dumped = settings.model_dump(mode="json")
reloaded_settings = MCPAuthorizationServerSettings(**dumped)

verifier = MCPAgentTokenVerifier(reloaded_settings)

metadata_response = Mock()
metadata_response.status_code = 200
metadata_response.json.return_value = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"introspection_endpoint": "https://auth.example.com/introspect",
"response_types_supported": ["code"],
}

introspect_response = Mock()
introspect_response.status_code = 200
introspect_response.json.return_value = {
"active": True,
"aud": "https://api.example.com/",
"sub": "user123",
"exp": 9999999999,
"iss": "https://auth.example.com/", # trailing slash
}

verifier._client.get = AsyncMock(return_value=metadata_response)
verifier._client.post = AsyncMock(return_value=introspect_response)

token = await verifier._introspect("test_token")

assert token is not None
assert token.subject == "user123"

await verifier.aclose()


@pytest.mark.asyncio
async def test_issuer_comparison_config_trailing_slash_token_without():
"""Test issuer comparison when config has trailing slash but token doesn't."""
settings = MCPAuthorizationServerSettings(
enabled=True,
issuer_url="https://auth.example.com",
resource_server_url="https://api.example.com",
expected_audiences=["https://api.example.com"],
)

dumped = settings.model_dump(mode="json")
reloaded_settings = MCPAuthorizationServerSettings(**dumped)

verifier = MCPAgentTokenVerifier(reloaded_settings)

metadata_response = Mock()
metadata_response.status_code = 200
metadata_response.json.return_value = {
"issuer": "https://auth.example.com",
"authorization_endpoint": "https://auth.example.com/authorize",
"token_endpoint": "https://auth.example.com/token",
"introspection_endpoint": "https://auth.example.com/introspect",
"response_types_supported": ["code"],
}

introspect_response = Mock()
introspect_response.status_code = 200
introspect_response.json.return_value = {
"active": True,
"aud": "https://api.example.com",
"sub": "user123",
"exp": 9999999999,
"iss": "https://auth.example.com", # No trailing slash
}

verifier._client.get = AsyncMock(return_value=metadata_response)
verifier._client.post = AsyncMock(return_value=introspect_response)

token = await verifier._introspect("test_token")

assert token is not None
assert token.subject == "user123"

await verifier.aclose()
Loading