11"""Unit tests for steering handler base class."""
22
3- import warnings
43from unittest .mock import AsyncMock , Mock
54
65import pytest
1514class TestSteeringHandler (SteeringHandler ):
1615 """Test implementation of SteeringHandler."""
1716
18- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
17+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
1918 return Proceed (reason = "Test proceed" )
2019
2120
@@ -66,7 +65,7 @@ async def test_proceed_action_flow():
6665 """Test complete flow with Proceed action."""
6766
6867 class ProceedHandler (SteeringHandler ):
69- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
68+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
7069 return Proceed (reason = "Test proceed" )
7170
7271 handler = ProceedHandler ()
@@ -85,7 +84,7 @@ async def test_guide_action_flow():
8584 """Test complete flow with Guide action."""
8685
8786 class GuideHandler (SteeringHandler ):
88- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
87+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
8988 return Guide (reason = "Test guidance" )
9089
9190 handler = GuideHandler ()
@@ -105,7 +104,7 @@ async def test_interrupt_action_approved_flow():
105104 """Test complete flow with Interrupt action when approved."""
106105
107106 class InterruptHandler (SteeringHandler ):
108- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
107+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
109108 return Interrupt (reason = "Need approval" )
110109
111110 handler = InterruptHandler ()
@@ -124,7 +123,7 @@ async def test_interrupt_action_denied_flow():
124123 """Test complete flow with Interrupt action when denied."""
125124
126125 class InterruptHandler (SteeringHandler ):
127- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
126+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
128127 return Interrupt (reason = "Need approval" )
129128
130129 handler = InterruptHandler ()
@@ -144,7 +143,7 @@ async def test_unknown_action_flow():
144143 """Test complete flow with unknown action type raises error."""
145144
146145 class UnknownActionHandler (SteeringHandler ):
147- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
146+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
148147 return Mock () # Not a valid SteeringAction
149148
150149 handler = UnknownActionHandler ()
@@ -160,7 +159,7 @@ def test_register_steering_hooks_override():
160159 """Test that _register_steering_hooks can be overridden."""
161160
162161 class CustomHandler (SteeringHandler ):
163- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
162+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
164163 return Proceed (reason = "Custom" )
165164
166165 def register_hooks (self , registry , ** kwargs ):
@@ -201,7 +200,7 @@ def __init__(self, context_callbacks=None):
201200 providers = [MockContextProvider (context_callbacks )] if context_callbacks else None
202201 super ().__init__ (context_providers = providers )
203202
204- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
203+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
205204 return Proceed (reason = "Test proceed" )
206205
207206
@@ -285,7 +284,7 @@ async def test_model_steering_proceed_action_flow():
285284 """Test model steering with Proceed action."""
286285
287286 class ModelProceedHandler (SteeringHandler ):
288- async def steer_after_model (self , agent , message , stop_reason , ** kwargs ):
287+ async def steer_after_model (self , * , agent , message , stop_reason , ** kwargs ):
289288 return Proceed (reason = "Model response accepted" )
290289
291290 handler = ModelProceedHandler ()
@@ -309,7 +308,7 @@ async def test_model_steering_guide_action_flow():
309308 """Test model steering with Guide action sets retry and adds message."""
310309
311310 class ModelGuideHandler (SteeringHandler ):
312- async def steer_after_model (self , agent , message , stop_reason , ** kwargs ):
311+ async def steer_after_model (self , * , agent , message , stop_reason , ** kwargs ):
313312 return Guide (reason = "Please improve your response" )
314313
315314 handler = ModelGuideHandler ()
@@ -342,7 +341,7 @@ def __init__(self):
342341 super ().__init__ ()
343342 self .steer_called = False
344343
345- async def steer_after_model (self , agent , message , stop_reason , ** kwargs ):
344+ async def steer_after_model (self , * , agent , message , stop_reason , ** kwargs ):
346345 self .steer_called = True
347346 return Proceed (reason = "Should not be called" )
348347
@@ -361,7 +360,7 @@ async def test_model_steering_unknown_action_raises_error():
361360 """Test model steering with unknown action type raises error."""
362361
363362 class UnknownModelActionHandler (SteeringHandler ):
364- async def steer_after_model (self , agent , message , stop_reason , ** kwargs ):
363+ async def steer_after_model (self , * , agent , message , stop_reason , ** kwargs ):
365364 return Mock () # Not a valid ModelSteeringAction
366365
367366 handler = UnknownModelActionHandler ()
@@ -382,7 +381,7 @@ async def test_model_steering_exception_handling():
382381 """Test model steering handles exceptions gracefully."""
383382
384383 class ExceptionModelHandler (SteeringHandler ):
385- async def steer_after_model (self , agent , message , stop_reason , ** kwargs ):
384+ async def steer_after_model (self , * , agent , message , stop_reason , ** kwargs ):
386385 raise RuntimeError ("Test exception" )
387386
388387 handler = ExceptionModelHandler ()
@@ -407,7 +406,7 @@ async def test_tool_steering_exception_handling():
407406 """Test tool steering handles exceptions gracefully."""
408407
409408 class ExceptionToolHandler (SteeringHandler ):
410- async def steer_before_tool (self , agent , tool_use , ** kwargs ):
409+ async def steer_before_tool (self , * , agent , tool_use , ** kwargs ):
411410 raise RuntimeError ("Test exception" )
412411
413412 handler = ExceptionToolHandler ()
@@ -431,7 +430,7 @@ async def test_default_steer_before_tool_returns_proceed():
431430 tool_use = {"name" : "test_tool" }
432431
433432 # Call the parent's default implementation
434- result = await SteeringHandler .steer_before_tool (handler , agent , tool_use )
433+ result = await SteeringHandler .steer_before_tool (handler , agent = agent , tool_use = tool_use )
435434
436435 assert isinstance (result , Proceed )
437436 assert "Default implementation" in result .reason
@@ -446,30 +445,12 @@ async def test_default_steer_after_model_returns_proceed():
446445 stop_reason = "end_turn"
447446
448447 # Call the parent's default implementation
449- result = await SteeringHandler .steer_after_model (handler , agent , message , stop_reason )
448+ result = await SteeringHandler .steer_after_model (handler , agent = agent , message = message , stop_reason = stop_reason )
450449
451450 assert isinstance (result , Proceed )
452451 assert "Default implementation" in result .reason
453452
454453
455- # Deprecated steer() method test
456- @pytest .mark .asyncio
457- async def test_deprecated_steer_method_emits_warning ():
458- """Test deprecated steer() method emits DeprecationWarning."""
459- handler = TestSteeringHandler ()
460- agent = Mock ()
461- tool_use = {"name" : "test_tool" }
462-
463- with warnings .catch_warnings (record = True ) as w :
464- warnings .simplefilter ("always" )
465- result = await handler .steer (agent , tool_use )
466-
467- assert len (w ) == 1
468- assert issubclass (w [0 ].category , DeprecationWarning )
469- assert "steer() is deprecated" in str (w [0 ].message )
470- assert isinstance (result , Proceed )
471-
472-
473454def test_register_hooks_registers_model_steering ():
474455 """Test that register_hooks registers model steering callback."""
475456 handler = TestSteeringHandler ()
0 commit comments