@@ -5377,3 +5377,137 @@ def dynamic_instr() -> str:
53775377 sys_texts = [p .content for p in req .parts if isinstance (p , SystemPromptPart )]
53785378 # The dynamic system prompt should still be present since overrides target instructions only
53795379 assert dynamic_value in sys_texts
5380+
5381+
5382+ def test_continue_conversation_that_ended_in_output_tool_call (allow_model_requests : None ):
5383+ def llm (messages : list [ModelMessage ], info : AgentInfo ) -> ModelResponse :
5384+ if any (isinstance (p , ToolReturnPart ) and p .tool_name == 'roll_dice' for p in messages [- 1 ].parts ):
5385+ return ModelResponse (
5386+ parts = [
5387+ ToolCallPart (
5388+ tool_name = 'final_result' ,
5389+ args = {'dice_roll' : 4 },
5390+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5391+ )
5392+ ]
5393+ )
5394+ return ModelResponse (
5395+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )]
5396+ )
5397+
5398+ class Result (BaseModel ):
5399+ dice_roll : int
5400+
5401+ agent = Agent (FunctionModel (llm ), output_type = Result )
5402+
5403+ @agent .tool_plain
5404+ def roll_dice () -> int :
5405+ return 4
5406+
5407+ result = agent .run_sync ('Roll me a dice.' )
5408+ messages = result .all_messages ()
5409+ assert messages == snapshot (
5410+ [
5411+ ModelRequest (
5412+ parts = [
5413+ UserPromptPart (
5414+ content = 'Roll me a dice.' ,
5415+ timestamp = IsDatetime (),
5416+ )
5417+ ]
5418+ ),
5419+ ModelResponse (
5420+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )],
5421+ usage = RequestUsage (input_tokens = 55 , output_tokens = 2 ),
5422+ model_name = 'function:llm:' ,
5423+ timestamp = IsDatetime (),
5424+ ),
5425+ ModelRequest (
5426+ parts = [
5427+ ToolReturnPart (
5428+ tool_name = 'roll_dice' ,
5429+ content = 4 ,
5430+ tool_call_id = 'pyd_ai_tool_call_id__roll_dice' ,
5431+ timestamp = IsDatetime (),
5432+ )
5433+ ]
5434+ ),
5435+ ModelResponse (
5436+ parts = [
5437+ ToolCallPart (
5438+ tool_name = 'final_result' ,
5439+ args = {'dice_roll' : 4 },
5440+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5441+ )
5442+ ],
5443+ usage = RequestUsage (input_tokens = 56 , output_tokens = 6 ),
5444+ model_name = 'function:llm:' ,
5445+ timestamp = IsDatetime (),
5446+ ),
5447+ ModelRequest (
5448+ parts = [
5449+ ToolReturnPart (
5450+ tool_name = 'final_result' ,
5451+ content = 'Final result processed.' ,
5452+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5453+ timestamp = IsDatetime (),
5454+ )
5455+ ]
5456+ ),
5457+ ]
5458+ )
5459+
5460+ result = agent .run_sync ('Roll me a dice again.' , message_history = messages )
5461+ new_messages = result .new_messages ()
5462+ assert new_messages == snapshot (
5463+ [
5464+ ModelRequest (
5465+ parts = [
5466+ UserPromptPart (
5467+ content = 'Roll me a dice again.' ,
5468+ timestamp = IsDatetime (),
5469+ )
5470+ ]
5471+ ),
5472+ ModelResponse (
5473+ parts = [ToolCallPart (tool_name = 'roll_dice' , args = {}, tool_call_id = 'pyd_ai_tool_call_id__roll_dice' )],
5474+ usage = RequestUsage (input_tokens = 66 , output_tokens = 8 ),
5475+ model_name = 'function:llm:' ,
5476+ timestamp = IsDatetime (),
5477+ ),
5478+ ModelRequest (
5479+ parts = [
5480+ ToolReturnPart (
5481+ tool_name = 'roll_dice' ,
5482+ content = 4 ,
5483+ tool_call_id = 'pyd_ai_tool_call_id__roll_dice' ,
5484+ timestamp = IsDatetime (),
5485+ )
5486+ ]
5487+ ),
5488+ ModelResponse (
5489+ parts = [
5490+ ToolCallPart (
5491+ tool_name = 'final_result' ,
5492+ args = {'dice_roll' : 4 },
5493+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5494+ )
5495+ ],
5496+ usage = RequestUsage (input_tokens = 67 , output_tokens = 12 ),
5497+ model_name = 'function:llm:' ,
5498+ timestamp = IsDatetime (),
5499+ ),
5500+ ModelRequest (
5501+ parts = [
5502+ ToolReturnPart (
5503+ tool_name = 'final_result' ,
5504+ content = 'Final result processed.' ,
5505+ tool_call_id = 'pyd_ai_tool_call_id__final_result' ,
5506+ timestamp = IsDatetime (),
5507+ )
5508+ ]
5509+ ),
5510+ ]
5511+ )
5512+
5513+ assert not any (isinstance (p , ToolReturnPart ) and p .tool_name == 'final_result' for p in new_messages [0 ].parts )
0 commit comments