diff --git a/mcp_proxy_for_aws/client.py b/mcp_proxy_for_aws/client.py index efdb6bd..f54f255 100644 --- a/mcp_proxy_for_aws/client.py +++ b/mcp_proxy_for_aws/client.py @@ -18,7 +18,12 @@ from botocore.credentials import Credentials from contextlib import _AsyncGeneratorContextManager from datetime import timedelta -from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client +from mcp.client.streamable_http import GetSessionIdCallback + +try: + from mcp.client.streamable_http import streamable_http_client +except ImportError: + from mcp.client.streamable_http import streamablehttp_client as streamable_http_client from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client from mcp.shared.message import SessionMessage from mcp_proxy_for_aws.sigv4_helper import SigV4HTTPXAuth @@ -114,7 +119,7 @@ def aws_iam_streamablehttp_client( auth = SigV4HTTPXAuth(creds, aws_service, region) # Return the streamable HTTP client context manager with AWS IAM authentication - return streamablehttp_client( + return streamable_http_client( url=endpoint, headers=headers, timeout=timeout, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2960e7d..4eda527 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -60,7 +60,7 @@ async def test_boto3_session_parameters( mock_read, mock_write, mock_get_session = mock_streams with patch('boto3.Session', return_value=mock_session) as mock_boto: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -94,7 +94,7 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic with patch('boto3.Session', return_value=mock_session): with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_auth = Mock() mock_auth_cls.return_value = mock_auth mock_stream_client.return_value.__aenter__ = AsyncMock( @@ -131,13 +131,13 @@ async def test_sigv4_auth_is_created_and_used(mock_session, mock_streams, servic async def test_streamable_client_parameters( mock_session, mock_streams, headers, timeout_value, sse_value, terminate_value ): - """Test the correctness of streamablehttp_client parameters.""" + """Test the correctness of streamable_http_client parameters.""" # Verify that connection settings are forwarded as-is to the streamable HTTP client. # timedelta values are allowed and compared directly here. mock_read, mock_write, mock_get_session = mock_streams with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -170,7 +170,7 @@ async def test_custom_httpx_client_factory_is_passed(mock_session, mock_streams) custom_factory = Mock() with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -198,7 +198,7 @@ async def mock_aexit(*_): cleanup_called = True with patch('boto3.Session', return_value=mock_session): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) ) @@ -220,7 +220,7 @@ async def test_credentials_parameter_with_region(mock_streams): creds = Credentials('test_key', 'test_secret', 'test_token') with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_auth = Mock() mock_auth_cls.return_value = mock_auth mock_stream_client.return_value.__aenter__ = AsyncMock( @@ -264,7 +264,7 @@ async def test_credentials_parameter_bypasses_boto3_session(mock_streams): with patch('boto3.Session') as mock_boto: with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'): - with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + with patch('mcp_proxy_for_aws.client.streamable_http_client') as mock_stream_client: mock_stream_client.return_value.__aenter__ = AsyncMock( return_value=(mock_read, mock_write, mock_get_session) )