From 0e9b5919ff7e67ae9ff3c4a9342081902fec09ee Mon Sep 17 00:00:00 2001 From: Sergey Konoplev Date: Mon, 13 Oct 2025 23:38:46 +0000 Subject: [PATCH] fix(redshift-mcp-server): Ensure transaction cleanup when protecting statement Fixes issue where failed user SQL left transactions open by preventing END statement execution. - Make sure to always execute END statement - Log and raise both user SQL and cleanup errors, including the combined error - Add unit tests for transaction error handling scenarios --- .../awslabs/redshift_mcp_server/redshift.py | 57 ++++++--- .../tests/test_redshift.py | 118 ++++++++++++++++++ 2 files changed, 158 insertions(+), 17 deletions(-) diff --git a/src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py b/src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py index 51d5a9bdcb..2d1a96c8ae 100644 --- a/src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py +++ b/src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py @@ -263,27 +263,50 @@ async def _execute_protected_statement( session_id=session_id, ) - # Execute user SQL with parameters - user_query_id = await _execute_statement( - cluster_info=cluster_info, - cluster_identifier=cluster_identifier, - database_name=database_name, - sql=sql, - parameters=parameters, - session_id=session_id, - ) + # Execute user SQL with parameters, ensuring transaction is always closed + user_query_id = None + user_sql_error = None - # Execute END statement to close transaction - await _execute_statement( - cluster_info=cluster_info, - cluster_identifier=cluster_identifier, - database_name=database_name, - sql='END;', - session_id=session_id, - ) + try: + user_query_id = await _execute_statement( + cluster_info=cluster_info, + cluster_identifier=cluster_identifier, + database_name=database_name, + sql=sql, + parameters=parameters, + session_id=session_id, + ) + except Exception as e: + user_sql_error = e + logger.error(f'User SQL execution failed: {e}') + + # Always execute END statement to close transaction + try: + await _execute_statement( + cluster_info=cluster_info, + cluster_identifier=cluster_identifier, + database_name=database_name, + sql='END;', + session_id=session_id, + ) + except Exception as end_error: + logger.error(f'END statement execution failed: {end_error}') + if user_sql_error: + # Both failed - raise combined error + raise Exception( + f'User SQL failed: {user_sql_error}; END statement failed: {end_error}' + ) + else: + # Only END failed + raise end_error + + # If user SQL failed but END succeeded, raise user SQL error + if user_sql_error: + raise user_sql_error # Get results from user query data_client = client_manager.redshift_data_client() + assert user_query_id is not None, 'user_query_id should not be None at this point' results_response = data_client.get_statement_result(Id=user_query_id) return results_response, user_query_id diff --git a/src/redshift-mcp-server/tests/test_redshift.py b/src/redshift-mcp-server/tests/test_redshift.py index 4bc9596db9..ea2685691a 100644 --- a/src/redshift-mcp-server/tests/test_redshift.py +++ b/src/redshift-mcp-server/tests/test_redshift.py @@ -391,6 +391,124 @@ async def test_execute_protected_statement_cluster_not_in_list(self, mocker): 'target-cluster', 'test-db', 'SELECT 1', allow_read_write=False ) + @pytest.mark.asyncio + async def test_execute_protected_statement_user_sql_fails_end_succeeds(self, mocker): + """Test user SQL fails but END succeeds - should raise user SQL error.""" + # Mock discover_clusters + mock_discover_clusters = mocker.patch( + 'awslabs.redshift_mcp_server.redshift.discover_clusters' + ) + mock_discover_clusters.return_value = [ + {'identifier': 'test-cluster', 'type': 'provisioned'} + ] + + # Mock session manager + mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager') + mock_session_manager.session = mocker.AsyncMock(return_value='session-123') + + # Mock _execute_statement to fail for user SQL, succeed for BEGIN and END + mock_execute_statement = mocker.patch( + 'awslabs.redshift_mcp_server.redshift._execute_statement' + ) + + def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs): + if sql == 'BEGIN READ ONLY;': + return 'begin-stmt-id' + elif sql == 'SELECT invalid_syntax': + raise Exception('SQL syntax error') + elif sql == 'END;': + return 'end-stmt-id' + return 'stmt-id' + + mock_execute_statement.side_effect = execute_side_effect + + with pytest.raises(Exception, match='SQL syntax error'): + await _execute_protected_statement( + 'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False + ) + + # Verify END was still called + assert mock_execute_statement.call_count == 3 + calls = mock_execute_statement.call_args_list + assert calls[0][1]['sql'] == 'BEGIN READ ONLY;' + assert calls[1][1]['sql'] == 'SELECT invalid_syntax' + assert calls[2][1]['sql'] == 'END;' + + @pytest.mark.asyncio + async def test_execute_protected_statement_user_sql_succeeds_end_fails(self, mocker): + """Test user SQL succeeds but END fails - should raise END error.""" + # Mock discover_clusters + mock_discover_clusters = mocker.patch( + 'awslabs.redshift_mcp_server.redshift.discover_clusters' + ) + mock_discover_clusters.return_value = [ + {'identifier': 'test-cluster', 'type': 'provisioned'} + ] + + # Mock session manager + mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager') + mock_session_manager.session = mocker.AsyncMock(return_value='session-123') + + # Mock _execute_statement to succeed for user SQL, fail for END + mock_execute_statement = mocker.patch( + 'awslabs.redshift_mcp_server.redshift._execute_statement' + ) + + def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs): + if sql == 'BEGIN READ ONLY;': + return 'begin-stmt-id' + elif sql == 'SELECT 1': + return 'user-stmt-id' + elif sql == 'END;': + raise Exception('END statement failed') + return 'stmt-id' + + mock_execute_statement.side_effect = execute_side_effect + + with pytest.raises(Exception, match='END statement failed'): + await _execute_protected_statement( + 'test-cluster', 'test-db', 'SELECT 1', allow_read_write=False + ) + + @pytest.mark.asyncio + async def test_execute_protected_statement_both_user_sql_and_end_fail(self, mocker): + """Test both user SQL and END fail - should raise combined error.""" + # Mock discover_clusters + mock_discover_clusters = mocker.patch( + 'awslabs.redshift_mcp_server.redshift.discover_clusters' + ) + mock_discover_clusters.return_value = [ + {'identifier': 'test-cluster', 'type': 'provisioned'} + ] + + # Mock session manager + mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager') + mock_session_manager.session = mocker.AsyncMock(return_value='session-123') + + # Mock _execute_statement to fail for both user SQL and END + mock_execute_statement = mocker.patch( + 'awslabs.redshift_mcp_server.redshift._execute_statement' + ) + + def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs): + if sql == 'BEGIN READ ONLY;': + return 'begin-stmt-id' + elif sql == 'SELECT invalid_syntax': + raise Exception('SQL syntax error') + elif sql == 'END;': + raise Exception('END statement failed') + return 'stmt-id' + + mock_execute_statement.side_effect = execute_side_effect + + with pytest.raises( + Exception, + match='User SQL failed: SQL syntax error; END statement failed: END statement failed', + ): + await _execute_protected_statement( + 'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False + ) + class TestExecuteStatement: """Tests for _execute_statement function."""