88import socket
99import time
1010from collections .abc import Generator
11+ from datetime import timedelta
12+ from unittest .mock import AsyncMock , MagicMock , patch
1113
1214import anyio
1315import httpx
1921from starlette .routing import Mount
2022
2123import mcp .types as types
24+ from mcp .client .auth import OAuthClientProvider
2225from mcp .client .session import ClientSession
23- from mcp .client .streamable_http import streamablehttp_client
26+ from mcp .client .streamable_http import StreamableHTTPTransport , streamablehttp_client
2427from mcp .server import Server
2528from mcp .server .streamable_http import (
2629 MCP_SESSION_ID_HEADER ,
3336 StreamId ,
3437)
3538from mcp .server .streamable_http_manager import StreamableHTTPSessionManager
39+ from mcp .shared .auth import OAuthToken
3640from mcp .shared .exceptions import McpError
3741from mcp .shared .message import (
3842 ClientMessageMetadata ,
@@ -1063,3 +1067,151 @@ async def run_tool():
10631067 assert not any (
10641068 n in captured_notifications_pre for n in captured_notifications
10651069 )
1070+
1071+
1072+ @pytest .mark .anyio
1073+ async def test_streamablehttp_client_auth_token_headers (basic_server , basic_server_url ):
1074+ """Test that auth tokens are correctly added to request headers."""
1075+ # Create a mock OAuth provider that returns a test token
1076+ mock_provider = AsyncMock (spec = OAuthClientProvider )
1077+ mock_provider .get_token .return_value = OAuthToken (
1078+ access_token = "test-token-123" ,
1079+ token_type = "bearer" ,
1080+ expires_in = 3600 ,
1081+ refresh_token = "refresh-token-123" ,
1082+ )
1083+
1084+ # Create client with auth provider
1085+ async with streamablehttp_client (
1086+ f"{ basic_server_url } /mcp" ,
1087+ auth_provider = mock_provider ,
1088+ ) as (
1089+ read_stream ,
1090+ write_stream ,
1091+ _ ,
1092+ _ ,
1093+ ):
1094+ async with ClientSession (read_stream , write_stream ) as session :
1095+ # Initialize the session
1096+ result = await session .initialize ()
1097+ assert isinstance (result , InitializeResult )
1098+
1099+ # Verify the mock provider was called
1100+ mock_provider .get_token .assert_called ()
1101+
1102+ # Verify the token was added to headers
1103+ # We can't directly check the headers since they're internal to the transport
1104+ # But we can verify the provider was called and returned a token
1105+ assert mock_provider .get_token .return_value .access_token == "test-token-123"
1106+
1107+
1108+ @pytest .mark .anyio
1109+ async def test_streamablehttp_client_no_auth_headers (basic_server , basic_server_url ):
1110+ """Test that no auth headers are added when no auth provider is configured."""
1111+ # Create client without auth provider
1112+ async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1113+ read_stream ,
1114+ write_stream ,
1115+ _ ,
1116+ _ ,
1117+ ):
1118+ async with ClientSession (read_stream , write_stream ) as session :
1119+ # Initialize the session
1120+ result = await session .initialize ()
1121+ assert isinstance (result , InitializeResult )
1122+
1123+ # No auth headers should be present
1124+ # We can't directly check the headers since they're internal to the transport
1125+ # But we can verify the session works without auth
1126+
1127+
1128+ @pytest .mark .anyio
1129+ async def test_streamablehttp_client_auth_token_format (basic_server , basic_server_url ):
1130+ """Test that auth tokens are correctly formatted in headers."""
1131+ # Create a mock OAuth provider that returns a test token
1132+ mock_provider = AsyncMock (spec = OAuthClientProvider )
1133+ mock_provider .get_token .return_value = OAuthToken (
1134+ access_token = "test-token-123" ,
1135+ token_type = "bearer" ,
1136+ expires_in = 3600 ,
1137+ refresh_token = "refresh-token-123" ,
1138+ )
1139+
1140+ # Create a custom httpx client to capture headers
1141+ captured_headers = {}
1142+
1143+ async def capture_headers (request ):
1144+ captured_headers .update (request .headers )
1145+ return httpx .Response (200 )
1146+
1147+ async with httpx .AsyncClient (
1148+ base_url = basic_server_url ,
1149+ transport = httpx .MockTransport (capture_headers ),
1150+ ) as client :
1151+ # Create transport with auth provider
1152+ transport = StreamableHTTPTransport (
1153+ f"{ basic_server_url } /mcp" ,
1154+ auth_provider = mock_provider ,
1155+ )
1156+
1157+ # Get headers with auth token
1158+ headers = await transport ._get_request_headers ()
1159+
1160+ # Verify the Authorization header is correctly formatted
1161+ assert headers ["Authorization" ] == "Bearer test-token-123"
1162+
1163+
1164+ @pytest .mark .anyio
1165+ async def test_streamablehttp_client_token_refresh (basic_server , basic_server_url ):
1166+ """Test that expired tokens are automatically refreshed."""
1167+ # Create a mock OAuth provider that returns an expired token
1168+ mock_provider = AsyncMock (spec = OAuthClientProvider )
1169+ mock_provider .get_token .return_value = OAuthToken (
1170+ access_token = "expired-token-123" ,
1171+ token_type = "bearer" ,
1172+ expires_in = 0 , # Expired token
1173+ refresh_token = "refresh-token-123" ,
1174+ )
1175+
1176+ # Mock the refresh response
1177+ refreshed_token = OAuthToken (
1178+ access_token = "new-token-456" ,
1179+ token_type = "bearer" ,
1180+ expires_in = 3600 ,
1181+ refresh_token = "new-refresh-token-456" ,
1182+ )
1183+
1184+ # Create client with auth provider
1185+ async with streamablehttp_client (
1186+ f"{ basic_server_url } /mcp" ,
1187+ auth_provider = mock_provider ,
1188+ ) as (
1189+ read_stream ,
1190+ write_stream ,
1191+ _ ,
1192+ _ ,
1193+ ):
1194+ async with ClientSession (read_stream , write_stream ) as session :
1195+ # Mock the OAuthAuthorization class to verify refresh token usage
1196+ with patch ("mcp.client.streamable_http.OAuthAuthorization" ) as mock_auth :
1197+ # Configure the mock to return our refreshed token
1198+ mock_auth_instance = AsyncMock ()
1199+ mock_auth_instance .authorize .return_value = refreshed_token
1200+ mock_auth .return_value = mock_auth_instance
1201+
1202+ # Initialize the session
1203+ result = await session .initialize ()
1204+ assert isinstance (result , InitializeResult )
1205+
1206+ # Verify the mock provider was called to get the initial token
1207+ mock_provider .get_token .assert_called ()
1208+
1209+ # Verify that authorize was called with the expired token
1210+ mock_auth_instance .authorize .assert_called_once ()
1211+
1212+ # Verify the refreshed token was saved back to the provider
1213+ mock_provider .save_token .assert_called_once_with (refreshed_token )
1214+
1215+ # Verify the new token is being used in subsequent requests
1216+ # by checking the Authorization header format
1217+ assert mock_auth_instance .get_token .return_value .access_token == "new-token-456"
0 commit comments