Skip to content

Commit 155af0d

Browse files
committed
Add streamable http tests
1 parent e122f54 commit 155af0d

File tree

1 file changed

+24
-109
lines changed

1 file changed

+24
-109
lines changed

tests/shared/test_streamable_http.py

Lines changed: 24 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import socket
99
import time
1010
from collections.abc import Generator
11-
from datetime import timedelta
12-
from unittest.mock import AsyncMock, MagicMock, patch
11+
from unittest.mock import AsyncMock
1312

1413
import anyio
1514
import httpx
@@ -1069,60 +1068,31 @@ async def run_tool():
10691068
)
10701069

10711070

1072-
@pytest.mark.anyio
1073-
async def test_streamablehttp_client_auth_token_headers(basic_server, basic_server_url):
1074-
"""Test that auth tokens are correctly added to request headers."""
1075-
# Create a mock OAuth provider that returns a test token
1076-
mock_provider = AsyncMock(spec=OAuthClientProvider)
1077-
mock_provider.get_token.return_value = OAuthToken(
1078-
access_token="test-token-123",
1079-
token_type="bearer",
1080-
expires_in=3600,
1081-
refresh_token="refresh-token-123",
1082-
)
1083-
1084-
# Create client with auth provider
1085-
async with streamablehttp_client(
1086-
f"{basic_server_url}/mcp",
1087-
auth_provider=mock_provider,
1088-
) as (
1089-
read_stream,
1090-
write_stream,
1091-
_,
1092-
_,
1093-
):
1094-
async with ClientSession(read_stream, write_stream) as session:
1095-
# Initialize the session
1096-
result = await session.initialize()
1097-
assert isinstance(result, InitializeResult)
1098-
1099-
# Verify the mock provider was called
1100-
mock_provider.get_token.assert_called()
1101-
1102-
# Verify the token was added to headers
1103-
# We can't directly check the headers since they're internal to the transport
1104-
# But we can verify the provider was called and returned a token
1105-
assert mock_provider.get_token.return_value.access_token == "test-token-123"
1106-
1107-
11081071
@pytest.mark.anyio
11091072
async def test_streamablehttp_client_no_auth_headers(basic_server, basic_server_url):
11101073
"""Test that no auth headers are added when no auth provider is configured."""
11111074
# Create client without auth provider
1112-
async with streamablehttp_client(f"{basic_server_url}/mcp") as (
1113-
read_stream,
1114-
write_stream,
1115-
_,
1116-
_,
1117-
):
1118-
async with ClientSession(read_stream, write_stream) as session:
1119-
# Initialize the session
1120-
result = await session.initialize()
1121-
assert isinstance(result, InitializeResult)
11221075

1123-
# No auth headers should be present
1124-
# We can't directly check the headers since they're internal to the transport
1125-
# But we can verify the session works without auth
1076+
captured_headers = {}
1077+
1078+
async def capture_headers(request):
1079+
captured_headers.update(request.headers)
1080+
return httpx.Response(200)
1081+
1082+
async with httpx.AsyncClient(
1083+
base_url=basic_server_url,
1084+
transport=httpx.MockTransport(capture_headers),
1085+
) as client:
1086+
# Create transport with auth provider
1087+
transport = StreamableHTTPTransport(
1088+
f"{basic_server_url}/mcp",
1089+
)
1090+
1091+
# Make a request using the client to trigger header capture
1092+
headers = await transport._get_request_headers()
1093+
await client.post("/mcp", headers=headers)
1094+
# Verify the Authorization header is correctly formatted
1095+
assert "authorization" not in captured_headers
11261096

11271097

11281098
@pytest.mark.anyio
@@ -1154,64 +1124,9 @@ async def capture_headers(request):
11541124
auth_provider=mock_provider,
11551125
)
11561126

1157-
# Get headers with auth token
1127+
# Make a request using the client to trigger header capture
11581128
headers = await transport._get_request_headers()
1129+
await client.post("/mcp", headers=headers)
11591130

11601131
# Verify the Authorization header is correctly formatted
1161-
assert headers["Authorization"] == "Bearer test-token-123"
1162-
1163-
1164-
@pytest.mark.anyio
1165-
async def test_streamablehttp_client_token_refresh(basic_server, basic_server_url):
1166-
"""Test that expired tokens are automatically refreshed."""
1167-
# Create a mock OAuth provider that returns an expired token
1168-
mock_provider = AsyncMock(spec=OAuthClientProvider)
1169-
mock_provider.get_token.return_value = OAuthToken(
1170-
access_token="expired-token-123",
1171-
token_type="bearer",
1172-
expires_in=0, # Expired token
1173-
refresh_token="refresh-token-123",
1174-
)
1175-
1176-
# Mock the refresh response
1177-
refreshed_token = OAuthToken(
1178-
access_token="new-token-456",
1179-
token_type="bearer",
1180-
expires_in=3600,
1181-
refresh_token="new-refresh-token-456",
1182-
)
1183-
1184-
# Create client with auth provider
1185-
async with streamablehttp_client(
1186-
f"{basic_server_url}/mcp",
1187-
auth_provider=mock_provider,
1188-
) as (
1189-
read_stream,
1190-
write_stream,
1191-
_,
1192-
_,
1193-
):
1194-
async with ClientSession(read_stream, write_stream) as session:
1195-
# Mock the OAuthAuthorization class to verify refresh token usage
1196-
with patch("mcp.client.streamable_http.OAuthAuthorization") as mock_auth:
1197-
# Configure the mock to return our refreshed token
1198-
mock_auth_instance = AsyncMock()
1199-
mock_auth_instance.authorize.return_value = refreshed_token
1200-
mock_auth.return_value = mock_auth_instance
1201-
1202-
# Initialize the session
1203-
result = await session.initialize()
1204-
assert isinstance(result, InitializeResult)
1205-
1206-
# Verify the mock provider was called to get the initial token
1207-
mock_provider.get_token.assert_called()
1208-
1209-
# Verify that authorize was called with the expired token
1210-
mock_auth_instance.authorize.assert_called_once()
1211-
1212-
# Verify the refreshed token was saved back to the provider
1213-
mock_provider.save_token.assert_called_once_with(refreshed_token)
1214-
1215-
# Verify the new token is being used in subsequent requests
1216-
# by checking the Authorization header format
1217-
assert mock_auth_instance.get_token.return_value.access_token == "new-token-456"
1132+
assert captured_headers["authorization"] == "Bearer test-token-123"

0 commit comments

Comments
 (0)