88import socket
99import time
1010from collections .abc import Generator
11- from datetime import timedelta
12- from unittest .mock import AsyncMock , MagicMock , patch
11+ from unittest .mock import AsyncMock
1312
1413import anyio
1514import httpx
@@ -1069,60 +1068,31 @@ async def run_tool():
10691068 )
10701069
10711070
1072- @pytest .mark .anyio
1073- async def test_streamablehttp_client_auth_token_headers (basic_server , basic_server_url ):
1074- """Test that auth tokens are correctly added to request headers."""
1075- # Create a mock OAuth provider that returns a test token
1076- mock_provider = AsyncMock (spec = OAuthClientProvider )
1077- mock_provider .get_token .return_value = OAuthToken (
1078- access_token = "test-token-123" ,
1079- token_type = "bearer" ,
1080- expires_in = 3600 ,
1081- refresh_token = "refresh-token-123" ,
1082- )
1083-
1084- # Create client with auth provider
1085- async with streamablehttp_client (
1086- f"{ basic_server_url } /mcp" ,
1087- auth_provider = mock_provider ,
1088- ) as (
1089- read_stream ,
1090- write_stream ,
1091- _ ,
1092- _ ,
1093- ):
1094- async with ClientSession (read_stream , write_stream ) as session :
1095- # Initialize the session
1096- result = await session .initialize ()
1097- assert isinstance (result , InitializeResult )
1098-
1099- # Verify the mock provider was called
1100- mock_provider .get_token .assert_called ()
1101-
1102- # Verify the token was added to headers
1103- # We can't directly check the headers since they're internal to the transport
1104- # But we can verify the provider was called and returned a token
1105- assert mock_provider .get_token .return_value .access_token == "test-token-123"
1106-
1107-
11081071@pytest .mark .anyio
11091072async def test_streamablehttp_client_no_auth_headers (basic_server , basic_server_url ):
11101073 """Test that no auth headers are added when no auth provider is configured."""
11111074 # Create client without auth provider
1112- async with streamablehttp_client (f"{ basic_server_url } /mcp" ) as (
1113- read_stream ,
1114- write_stream ,
1115- _ ,
1116- _ ,
1117- ):
1118- async with ClientSession (read_stream , write_stream ) as session :
1119- # Initialize the session
1120- result = await session .initialize ()
1121- assert isinstance (result , InitializeResult )
11221075
1123- # No auth headers should be present
1124- # We can't directly check the headers since they're internal to the transport
1125- # But we can verify the session works without auth
1076+ captured_headers = {}
1077+
1078+ async def capture_headers (request ):
1079+ captured_headers .update (request .headers )
1080+ return httpx .Response (200 )
1081+
1082+ async with httpx .AsyncClient (
1083+ base_url = basic_server_url ,
1084+ transport = httpx .MockTransport (capture_headers ),
1085+ ) as client :
1086+ # Create transport with auth provider
1087+ transport = StreamableHTTPTransport (
1088+ f"{ basic_server_url } /mcp" ,
1089+ )
1090+
1091+ # Make a request using the client to trigger header capture
1092+ headers = await transport ._get_request_headers ()
1093+ await client .post ("/mcp" , headers = headers )
1094+ # Verify the Authorization header is correctly formatted
1095+ assert "authorization" not in captured_headers
11261096
11271097
11281098@pytest .mark .anyio
@@ -1154,64 +1124,9 @@ async def capture_headers(request):
11541124 auth_provider = mock_provider ,
11551125 )
11561126
1157- # Get headers with auth token
1127+ # Make a request using the client to trigger header capture
11581128 headers = await transport ._get_request_headers ()
1129+ await client .post ("/mcp" , headers = headers )
11591130
11601131 # Verify the Authorization header is correctly formatted
1161- assert headers ["Authorization" ] == "Bearer test-token-123"
1162-
1163-
1164- @pytest .mark .anyio
1165- async def test_streamablehttp_client_token_refresh (basic_server , basic_server_url ):
1166- """Test that expired tokens are automatically refreshed."""
1167- # Create a mock OAuth provider that returns an expired token
1168- mock_provider = AsyncMock (spec = OAuthClientProvider )
1169- mock_provider .get_token .return_value = OAuthToken (
1170- access_token = "expired-token-123" ,
1171- token_type = "bearer" ,
1172- expires_in = 0 , # Expired token
1173- refresh_token = "refresh-token-123" ,
1174- )
1175-
1176- # Mock the refresh response
1177- refreshed_token = OAuthToken (
1178- access_token = "new-token-456" ,
1179- token_type = "bearer" ,
1180- expires_in = 3600 ,
1181- refresh_token = "new-refresh-token-456" ,
1182- )
1183-
1184- # Create client with auth provider
1185- async with streamablehttp_client (
1186- f"{ basic_server_url } /mcp" ,
1187- auth_provider = mock_provider ,
1188- ) as (
1189- read_stream ,
1190- write_stream ,
1191- _ ,
1192- _ ,
1193- ):
1194- async with ClientSession (read_stream , write_stream ) as session :
1195- # Mock the OAuthAuthorization class to verify refresh token usage
1196- with patch ("mcp.client.streamable_http.OAuthAuthorization" ) as mock_auth :
1197- # Configure the mock to return our refreshed token
1198- mock_auth_instance = AsyncMock ()
1199- mock_auth_instance .authorize .return_value = refreshed_token
1200- mock_auth .return_value = mock_auth_instance
1201-
1202- # Initialize the session
1203- result = await session .initialize ()
1204- assert isinstance (result , InitializeResult )
1205-
1206- # Verify the mock provider was called to get the initial token
1207- mock_provider .get_token .assert_called ()
1208-
1209- # Verify that authorize was called with the expired token
1210- mock_auth_instance .authorize .assert_called_once ()
1211-
1212- # Verify the refreshed token was saved back to the provider
1213- mock_provider .save_token .assert_called_once_with (refreshed_token )
1214-
1215- # Verify the new token is being used in subsequent requests
1216- # by checking the Authorization header format
1217- assert mock_auth_instance .get_token .return_value .access_token == "new-token-456"
1132+ assert captured_headers ["authorization" ] == "Bearer test-token-123"
0 commit comments