Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions mcp_proxy_for_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
from botocore.credentials import Credentials
from contextlib import _AsyncGeneratorContextManager
from datetime import timedelta
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
from mcp.client.streamable_http import GetSessionIdCallback

try:
from mcp.client.streamable_http import streamable_http_client
except ImportError:
from mcp.client.streamable_http import streamablehttp_client as streamable_http_client
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth
Expand Down Expand Up @@ -114,7 +119,7 @@ def aws_iam_streamablehttp_client(
auth = SigV4HTTPXAuth(creds, aws_service, region)

# Return the streamable HTTP client context manager with AWS IAM authentication
return streamablehttp_client(
return streamable_http_client(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding from this change modelcontextprotocol/python-sdk#1177 is that the signature has changed on how the client is passed.

If streamable_http_client is imported instead of streamablehttp_client, I think this will break because the new implementation will use a default http_client which will not have the headers, auth etc here set up.

url=endpoint,
headers=headers,
timeout=timeout,
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def test_boto3_session_parameters(
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session) as mock_boto:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_stream_client.return_value.__aenter__ = AsyncMock(
Expand Down Expand Up @@ -131,13 +131,13 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic
async def test_streamable_client_parameters(
mock_session, mock_streams, headers, timeout_value, sse_value, terminate_value
):
"""Test the correctness of streamablehttp_client parameters."""
"""Test the correctness of streamable_http_client parameters."""
# Verify that connection settings are forwarded as-is to the streamable HTTP client.
# timedelta values are allowed and compared directly here.
mock_read, mock_write, mock_get_session = mock_streams

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -170,7 +170,7 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams)
custom_factory = Mock()

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down Expand Up @@ -198,7 +198,7 @@ async def mock_aexit(*_):
cleanup_called = True

with patch('boto3.Session', return_value=mock_session):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand All @@ -220,7 +220,7 @@ async def test_credentials_parameter_with_region(mock_streams):
creds = Credentials('test_key', 'test_secret', 'test_token')

with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_stream_client.return_value.__aenter__ = AsyncMock(
Expand Down Expand Up @@ -264,7 +264,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams):

with patch('boto3.Session') as mock_boto:
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
Expand Down
Loading