@@ -391,6 +391,124 @@ async def test_execute_protected_statement_cluster_not_in_list(self, mocker):
391391 'target-cluster' , 'test-db' , 'SELECT 1' , allow_read_write = False
392392 )
393393
394+ @pytest .mark .asyncio
395+ async def test_execute_protected_statement_user_sql_fails_end_succeeds (self , mocker ):
396+ """Test user SQL fails but END succeeds - should raise user SQL error."""
397+ # Mock discover_clusters
398+ mock_discover_clusters = mocker .patch (
399+ 'awslabs.redshift_mcp_server.redshift.discover_clusters'
400+ )
401+ mock_discover_clusters .return_value = [
402+ {'identifier' : 'test-cluster' , 'type' : 'provisioned' }
403+ ]
404+
405+ # Mock session manager
406+ mock_session_manager = mocker .patch ('awslabs.redshift_mcp_server.redshift.session_manager' )
407+ mock_session_manager .session = mocker .AsyncMock (return_value = 'session-123' )
408+
409+ # Mock _execute_statement to fail for user SQL, succeed for BEGIN and END
410+ mock_execute_statement = mocker .patch (
411+ 'awslabs.redshift_mcp_server.redshift._execute_statement'
412+ )
413+
414+ def execute_side_effect (cluster_info , cluster_identifier , database_name , sql , ** kwargs ):
415+ if sql == 'BEGIN READ ONLY;' :
416+ return 'begin-stmt-id'
417+ elif sql == 'SELECT invalid_syntax' :
418+ raise Exception ('SQL syntax error' )
419+ elif sql == 'END;' :
420+ return 'end-stmt-id'
421+ return 'stmt-id'
422+
423+ mock_execute_statement .side_effect = execute_side_effect
424+
425+ with pytest .raises (Exception , match = 'SQL syntax error' ):
426+ await _execute_protected_statement (
427+ 'test-cluster' , 'test-db' , 'SELECT invalid_syntax' , allow_read_write = False
428+ )
429+
430+ # Verify END was still called
431+ assert mock_execute_statement .call_count == 3
432+ calls = mock_execute_statement .call_args_list
433+ assert calls [0 ][1 ]['sql' ] == 'BEGIN READ ONLY;'
434+ assert calls [1 ][1 ]['sql' ] == 'SELECT invalid_syntax'
435+ assert calls [2 ][1 ]['sql' ] == 'END;'
436+
437+ @pytest .mark .asyncio
438+ async def test_execute_protected_statement_user_sql_succeeds_end_fails (self , mocker ):
439+ """Test user SQL succeeds but END fails - should raise END error."""
440+ # Mock discover_clusters
441+ mock_discover_clusters = mocker .patch (
442+ 'awslabs.redshift_mcp_server.redshift.discover_clusters'
443+ )
444+ mock_discover_clusters .return_value = [
445+ {'identifier' : 'test-cluster' , 'type' : 'provisioned' }
446+ ]
447+
448+ # Mock session manager
449+ mock_session_manager = mocker .patch ('awslabs.redshift_mcp_server.redshift.session_manager' )
450+ mock_session_manager .session = mocker .AsyncMock (return_value = 'session-123' )
451+
452+ # Mock _execute_statement to succeed for user SQL, fail for END
453+ mock_execute_statement = mocker .patch (
454+ 'awslabs.redshift_mcp_server.redshift._execute_statement'
455+ )
456+
457+ def execute_side_effect (cluster_info , cluster_identifier , database_name , sql , ** kwargs ):
458+ if sql == 'BEGIN READ ONLY;' :
459+ return 'begin-stmt-id'
460+ elif sql == 'SELECT 1' :
461+ return 'user-stmt-id'
462+ elif sql == 'END;' :
463+ raise Exception ('END statement failed' )
464+ return 'stmt-id'
465+
466+ mock_execute_statement .side_effect = execute_side_effect
467+
468+ with pytest .raises (Exception , match = 'END statement failed' ):
469+ await _execute_protected_statement (
470+ 'test-cluster' , 'test-db' , 'SELECT 1' , allow_read_write = False
471+ )
472+
473+ @pytest .mark .asyncio
474+ async def test_execute_protected_statement_both_user_sql_and_end_fail (self , mocker ):
475+ """Test both user SQL and END fail - should raise combined error."""
476+ # Mock discover_clusters
477+ mock_discover_clusters = mocker .patch (
478+ 'awslabs.redshift_mcp_server.redshift.discover_clusters'
479+ )
480+ mock_discover_clusters .return_value = [
481+ {'identifier' : 'test-cluster' , 'type' : 'provisioned' }
482+ ]
483+
484+ # Mock session manager
485+ mock_session_manager = mocker .patch ('awslabs.redshift_mcp_server.redshift.session_manager' )
486+ mock_session_manager .session = mocker .AsyncMock (return_value = 'session-123' )
487+
488+ # Mock _execute_statement to fail for both user SQL and END
489+ mock_execute_statement = mocker .patch (
490+ 'awslabs.redshift_mcp_server.redshift._execute_statement'
491+ )
492+
493+ def execute_side_effect (cluster_info , cluster_identifier , database_name , sql , ** kwargs ):
494+ if sql == 'BEGIN READ ONLY;' :
495+ return 'begin-stmt-id'
496+ elif sql == 'SELECT invalid_syntax' :
497+ raise Exception ('SQL syntax error' )
498+ elif sql == 'END;' :
499+ raise Exception ('END statement failed' )
500+ return 'stmt-id'
501+
502+ mock_execute_statement .side_effect = execute_side_effect
503+
504+ with pytest .raises (
505+ Exception ,
506+ match = 'User SQL failed: SQL syntax error; END statement failed: END statement failed' ,
507+ ):
508+ await _execute_protected_statement (
509+ 'test-cluster' , 'test-db' , 'SELECT invalid_syntax' , allow_read_write = False
510+ )
511+
394512
395513class TestExecuteStatement :
396514 """Tests for _execute_statement function."""
0 commit comments