Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,27 +263,50 @@ async def _execute_protected_statement(
session_id=session_id,
)

# Execute user SQL with parameters
user_query_id = await _execute_statement(
cluster_info=cluster_info,
cluster_identifier=cluster_identifier,
database_name=database_name,
sql=sql,
parameters=parameters,
session_id=session_id,
)
# Execute user SQL with parameters, ensuring transaction is always closed
user_query_id = None
user_sql_error = None

# Execute END statement to close transaction
await _execute_statement(
cluster_info=cluster_info,
cluster_identifier=cluster_identifier,
database_name=database_name,
sql='END;',
session_id=session_id,
)
try:
user_query_id = await _execute_statement(
cluster_info=cluster_info,
cluster_identifier=cluster_identifier,
database_name=database_name,
sql=sql,
parameters=parameters,
session_id=session_id,
)
except Exception as e:
user_sql_error = e
logger.error(f'User SQL execution failed: {e}')

# Always execute END statement to close transaction
try:
await _execute_statement(
cluster_info=cluster_info,
cluster_identifier=cluster_identifier,
database_name=database_name,
sql='END;',
session_id=session_id,
)
except Exception as end_error:
logger.error(f'END statement execution failed: {end_error}')
if user_sql_error:
# Both failed - raise combined error
raise Exception(
f'User SQL failed: {user_sql_error}; END statement failed: {end_error}'
)
else:
# Only END failed
raise end_error

# If user SQL failed but END succeeded, raise user SQL error
if user_sql_error:
raise user_sql_error

# Get results from user query
data_client = client_manager.redshift_data_client()
assert user_query_id is not None, 'user_query_id should not be None at this point'
results_response = data_client.get_statement_result(Id=user_query_id)
return results_response, user_query_id

Expand Down
118 changes: 118 additions & 0 deletions src/redshift-mcp-server/tests/test_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,124 @@ async def test_execute_protected_statement_cluster_not_in_list(self, mocker):
'target-cluster', 'test-db', 'SELECT 1', allow_read_write=False
)

@pytest.mark.asyncio
async def test_execute_protected_statement_user_sql_fails_end_succeeds(self, mocker):
"""Test user SQL fails but END succeeds - should raise user SQL error."""
# Mock discover_clusters
mock_discover_clusters = mocker.patch(
'awslabs.redshift_mcp_server.redshift.discover_clusters'
)
mock_discover_clusters.return_value = [
{'identifier': 'test-cluster', 'type': 'provisioned'}
]

# Mock session manager
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')

# Mock _execute_statement to fail for user SQL, succeed for BEGIN and END
mock_execute_statement = mocker.patch(
'awslabs.redshift_mcp_server.redshift._execute_statement'
)

def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
if sql == 'BEGIN READ ONLY;':
return 'begin-stmt-id'
elif sql == 'SELECT invalid_syntax':
raise Exception('SQL syntax error')
elif sql == 'END;':
return 'end-stmt-id'
return 'stmt-id'

mock_execute_statement.side_effect = execute_side_effect

with pytest.raises(Exception, match='SQL syntax error'):
await _execute_protected_statement(
'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False
)

# Verify END was still called
assert mock_execute_statement.call_count == 3
calls = mock_execute_statement.call_args_list
assert calls[0][1]['sql'] == 'BEGIN READ ONLY;'
assert calls[1][1]['sql'] == 'SELECT invalid_syntax'
assert calls[2][1]['sql'] == 'END;'

@pytest.mark.asyncio
async def test_execute_protected_statement_user_sql_succeeds_end_fails(self, mocker):
"""Test user SQL succeeds but END fails - should raise END error."""
# Mock discover_clusters
mock_discover_clusters = mocker.patch(
'awslabs.redshift_mcp_server.redshift.discover_clusters'
)
mock_discover_clusters.return_value = [
{'identifier': 'test-cluster', 'type': 'provisioned'}
]

# Mock session manager
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')

# Mock _execute_statement to succeed for user SQL, fail for END
mock_execute_statement = mocker.patch(
'awslabs.redshift_mcp_server.redshift._execute_statement'
)

def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
if sql == 'BEGIN READ ONLY;':
return 'begin-stmt-id'
elif sql == 'SELECT 1':
return 'user-stmt-id'
elif sql == 'END;':
raise Exception('END statement failed')
return 'stmt-id'

mock_execute_statement.side_effect = execute_side_effect

with pytest.raises(Exception, match='END statement failed'):
await _execute_protected_statement(
'test-cluster', 'test-db', 'SELECT 1', allow_read_write=False
)

@pytest.mark.asyncio
async def test_execute_protected_statement_both_user_sql_and_end_fail(self, mocker):
"""Test both user SQL and END fail - should raise combined error."""
# Mock discover_clusters
mock_discover_clusters = mocker.patch(
'awslabs.redshift_mcp_server.redshift.discover_clusters'
)
mock_discover_clusters.return_value = [
{'identifier': 'test-cluster', 'type': 'provisioned'}
]

# Mock session manager
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')

# Mock _execute_statement to fail for both user SQL and END
mock_execute_statement = mocker.patch(
'awslabs.redshift_mcp_server.redshift._execute_statement'
)

def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
if sql == 'BEGIN READ ONLY;':
return 'begin-stmt-id'
elif sql == 'SELECT invalid_syntax':
raise Exception('SQL syntax error')
elif sql == 'END;':
raise Exception('END statement failed')
return 'stmt-id'

mock_execute_statement.side_effect = execute_side_effect

with pytest.raises(
Exception,
match='User SQL failed: SQL syntax error; END statement failed: END statement failed',
):
await _execute_protected_statement(
'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False
)


class TestExecuteStatement:
"""Tests for _execute_statement function."""
Expand Down