Skip to content

Commit f8ac8dc

Browse files
grayhempSergey Konoplev
andauthored
fix(redshift-mcp-server): Ensure transaction cleanup when protecting statement (#1512)
Fixes issue where failed user SQL left transactions open by preventing END statement execution. - Make sure to always execute END statement - Log and raise both user SQL and cleanup errors, including the combined error - Add unit tests for transaction error handling scenarios Co-authored-by: Sergey Konoplev <[email protected]>
1 parent 946b346 commit f8ac8dc

File tree

2 files changed

+158
-17
lines changed

2 files changed

+158
-17
lines changed

src/redshift-mcp-server/awslabs/redshift_mcp_server/redshift.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -263,27 +263,50 @@ async def _execute_protected_statement(
263263
session_id=session_id,
264264
)
265265

266-
# Execute user SQL with parameters
267-
user_query_id = await _execute_statement(
268-
cluster_info=cluster_info,
269-
cluster_identifier=cluster_identifier,
270-
database_name=database_name,
271-
sql=sql,
272-
parameters=parameters,
273-
session_id=session_id,
274-
)
266+
# Execute user SQL with parameters, ensuring transaction is always closed
267+
user_query_id = None
268+
user_sql_error = None
275269

276-
# Execute END statement to close transaction
277-
await _execute_statement(
278-
cluster_info=cluster_info,
279-
cluster_identifier=cluster_identifier,
280-
database_name=database_name,
281-
sql='END;',
282-
session_id=session_id,
283-
)
270+
try:
271+
user_query_id = await _execute_statement(
272+
cluster_info=cluster_info,
273+
cluster_identifier=cluster_identifier,
274+
database_name=database_name,
275+
sql=sql,
276+
parameters=parameters,
277+
session_id=session_id,
278+
)
279+
except Exception as e:
280+
user_sql_error = e
281+
logger.error(f'User SQL execution failed: {e}')
282+
283+
# Always execute END statement to close transaction
284+
try:
285+
await _execute_statement(
286+
cluster_info=cluster_info,
287+
cluster_identifier=cluster_identifier,
288+
database_name=database_name,
289+
sql='END;',
290+
session_id=session_id,
291+
)
292+
except Exception as end_error:
293+
logger.error(f'END statement execution failed: {end_error}')
294+
if user_sql_error:
295+
# Both failed - raise combined error
296+
raise Exception(
297+
f'User SQL failed: {user_sql_error}; END statement failed: {end_error}'
298+
)
299+
else:
300+
# Only END failed
301+
raise end_error
302+
303+
# If user SQL failed but END succeeded, raise user SQL error
304+
if user_sql_error:
305+
raise user_sql_error
284306

285307
# Get results from user query
286308
data_client = client_manager.redshift_data_client()
309+
assert user_query_id is not None, 'user_query_id should not be None at this point'
287310
results_response = data_client.get_statement_result(Id=user_query_id)
288311
return results_response, user_query_id
289312

src/redshift-mcp-server/tests/test_redshift.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,124 @@ async def test_execute_protected_statement_cluster_not_in_list(self, mocker):
391391
'target-cluster', 'test-db', 'SELECT 1', allow_read_write=False
392392
)
393393

394+
@pytest.mark.asyncio
395+
async def test_execute_protected_statement_user_sql_fails_end_succeeds(self, mocker):
396+
"""Test user SQL fails but END succeeds - should raise user SQL error."""
397+
# Mock discover_clusters
398+
mock_discover_clusters = mocker.patch(
399+
'awslabs.redshift_mcp_server.redshift.discover_clusters'
400+
)
401+
mock_discover_clusters.return_value = [
402+
{'identifier': 'test-cluster', 'type': 'provisioned'}
403+
]
404+
405+
# Mock session manager
406+
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
407+
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')
408+
409+
# Mock _execute_statement to fail for user SQL, succeed for BEGIN and END
410+
mock_execute_statement = mocker.patch(
411+
'awslabs.redshift_mcp_server.redshift._execute_statement'
412+
)
413+
414+
def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
415+
if sql == 'BEGIN READ ONLY;':
416+
return 'begin-stmt-id'
417+
elif sql == 'SELECT invalid_syntax':
418+
raise Exception('SQL syntax error')
419+
elif sql == 'END;':
420+
return 'end-stmt-id'
421+
return 'stmt-id'
422+
423+
mock_execute_statement.side_effect = execute_side_effect
424+
425+
with pytest.raises(Exception, match='SQL syntax error'):
426+
await _execute_protected_statement(
427+
'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False
428+
)
429+
430+
# Verify END was still called
431+
assert mock_execute_statement.call_count == 3
432+
calls = mock_execute_statement.call_args_list
433+
assert calls[0][1]['sql'] == 'BEGIN READ ONLY;'
434+
assert calls[1][1]['sql'] == 'SELECT invalid_syntax'
435+
assert calls[2][1]['sql'] == 'END;'
436+
437+
@pytest.mark.asyncio
438+
async def test_execute_protected_statement_user_sql_succeeds_end_fails(self, mocker):
439+
"""Test user SQL succeeds but END fails - should raise END error."""
440+
# Mock discover_clusters
441+
mock_discover_clusters = mocker.patch(
442+
'awslabs.redshift_mcp_server.redshift.discover_clusters'
443+
)
444+
mock_discover_clusters.return_value = [
445+
{'identifier': 'test-cluster', 'type': 'provisioned'}
446+
]
447+
448+
# Mock session manager
449+
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
450+
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')
451+
452+
# Mock _execute_statement to succeed for user SQL, fail for END
453+
mock_execute_statement = mocker.patch(
454+
'awslabs.redshift_mcp_server.redshift._execute_statement'
455+
)
456+
457+
def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
458+
if sql == 'BEGIN READ ONLY;':
459+
return 'begin-stmt-id'
460+
elif sql == 'SELECT 1':
461+
return 'user-stmt-id'
462+
elif sql == 'END;':
463+
raise Exception('END statement failed')
464+
return 'stmt-id'
465+
466+
mock_execute_statement.side_effect = execute_side_effect
467+
468+
with pytest.raises(Exception, match='END statement failed'):
469+
await _execute_protected_statement(
470+
'test-cluster', 'test-db', 'SELECT 1', allow_read_write=False
471+
)
472+
473+
@pytest.mark.asyncio
474+
async def test_execute_protected_statement_both_user_sql_and_end_fail(self, mocker):
475+
"""Test both user SQL and END fail - should raise combined error."""
476+
# Mock discover_clusters
477+
mock_discover_clusters = mocker.patch(
478+
'awslabs.redshift_mcp_server.redshift.discover_clusters'
479+
)
480+
mock_discover_clusters.return_value = [
481+
{'identifier': 'test-cluster', 'type': 'provisioned'}
482+
]
483+
484+
# Mock session manager
485+
mock_session_manager = mocker.patch('awslabs.redshift_mcp_server.redshift.session_manager')
486+
mock_session_manager.session = mocker.AsyncMock(return_value='session-123')
487+
488+
# Mock _execute_statement to fail for both user SQL and END
489+
mock_execute_statement = mocker.patch(
490+
'awslabs.redshift_mcp_server.redshift._execute_statement'
491+
)
492+
493+
def execute_side_effect(cluster_info, cluster_identifier, database_name, sql, **kwargs):
494+
if sql == 'BEGIN READ ONLY;':
495+
return 'begin-stmt-id'
496+
elif sql == 'SELECT invalid_syntax':
497+
raise Exception('SQL syntax error')
498+
elif sql == 'END;':
499+
raise Exception('END statement failed')
500+
return 'stmt-id'
501+
502+
mock_execute_statement.side_effect = execute_side_effect
503+
504+
with pytest.raises(
505+
Exception,
506+
match='User SQL failed: SQL syntax error; END statement failed: END statement failed',
507+
):
508+
await _execute_protected_statement(
509+
'test-cluster', 'test-db', 'SELECT invalid_syntax', allow_read_write=False
510+
)
511+
394512

395513
class TestExecuteStatement:
396514
"""Tests for _execute_statement function."""

0 commit comments

Comments
 (0)