Skip to content

Commit 559a832

Browse files
authored
fix: #1840 roll back session changes when Guardrail tripwire is triggered (#1843)
1 parent 9407743 commit 559a832

File tree

4 files changed, with 153 additions and 10 deletions.

src/agents/run.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,13 @@ async def run(
665665
tool_output_guardrail_results=tool_output_guardrail_results,
666666
context_wrapper=context_wrapper,
667667
)
668-
await self._save_result_to_session(session, [], turn_result.new_step_items)
668+
if not any(
669+
guardrail_result.output.tripwire_triggered
670+
for guardrail_result in input_guardrail_results
671+
):
672+
await self._save_result_to_session(
673+
session, [], turn_result.new_step_items
674+
)
669675

670676
return result
671677
elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -674,7 +680,13 @@ async def run(
674680
current_span = None
675681
should_run_agent_start_hooks = True
676682
elif isinstance(turn_result.next_step, NextStepRunAgain):
677-
await self._save_result_to_session(session, [], turn_result.new_step_items)
683+
if not any(
684+
guardrail_result.output.tripwire_triggered
685+
for guardrail_result in input_guardrail_results
686+
):
687+
await self._save_result_to_session(
688+
session, [], turn_result.new_step_items
689+
)
678690
else:
679691
raise AgentsException(
680692
f"Unknown next step type: {type(turn_result.next_step)}"
@@ -1043,15 +1055,29 @@ async def _start_streaming(
10431055
streamed_result.is_complete = True
10441056

10451057
# Save the conversation to session if enabled
1046-
await AgentRunner._save_result_to_session(
1047-
session, [], turn_result.new_step_items
1048-
)
1058+
if session is not None:
1059+
should_skip_session_save = (
1060+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1061+
streamed_result
1062+
)
1063+
)
1064+
if should_skip_session_save is False:
1065+
await AgentRunner._save_result_to_session(
1066+
session, [], turn_result.new_step_items
1067+
)
10491068

10501069
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
10511070
elif isinstance(turn_result.next_step, NextStepRunAgain):
1052-
await AgentRunner._save_result_to_session(
1053-
session, [], turn_result.new_step_items
1054-
)
1071+
if session is not None:
1072+
should_skip_session_save = (
1073+
await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
1074+
streamed_result
1075+
)
1076+
)
1077+
if should_skip_session_save is False:
1078+
await AgentRunner._save_result_to_session(
1079+
session, [], turn_result.new_step_items
1080+
)
10551081
except AgentsException as exc:
10561082
streamed_result.is_complete = True
10571083
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1746,6 +1772,24 @@ async def _save_result_to_session(
17461772
items_to_save = input_list + new_items_as_input
17471773
await session.add_items(items_to_save)
17481774

1775+
@staticmethod
async def _input_guardrail_tripwire_triggered_for_stream(
    streamed_result: RunResultStreaming,
) -> bool:
    """Report whether an input guardrail tripped during a streamed run.

    If the streamed result carries an input-guardrails task that has not
    finished yet, that task is awaited before the collected results are
    inspected.

    Args:
        streamed_result: The streaming result whose ``_input_guardrails_task``
            and ``input_guardrail_results`` attributes are read.

    Returns:
        ``False`` when no guardrails task is attached. Otherwise ``True`` if
        at least one entry of ``input_guardrail_results`` has
        ``output.tripwire_triggered`` set, else ``False``.
    """
    guardrails_task = streamed_result._input_guardrails_task
    if guardrails_task is None:
        return False

    # Let a still-running guardrails task finish before reading the results.
    if not guardrails_task.done():
        await guardrails_task

    for result in streamed_result.input_guardrail_results:
        if result.output.tripwire_triggered:
            return True
    return False
1792+
17491793

17501794
DEFAULT_AGENT_RUNNER = AgentRunner()
17511795
_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)

tests/test_agent_runner.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import json
45
import tempfile
56
from pathlib import Path
6-
from typing import Any
7+
from typing import Any, cast
78
from unittest.mock import patch
89

910
import pytest
@@ -39,6 +40,7 @@
3940
get_text_input_item,
4041
get_text_message,
4142
)
43+
from .utils.simple_session import SimpleListSession
4244

4345

4446
@pytest.mark.asyncio
@@ -542,6 +544,40 @@ def guardrail_function(
542544
await Runner.run(agent, input="user_message")
543545

544546

547+
@pytest.mark.asyncio
async def test_input_guardrail_tripwire_does_not_save_assistant_message_to_session():
    """A tripped input guardrail leaves exactly one item, with role "user", in the session."""

    async def guardrail_function(
        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
    ) -> GuardrailFunctionOutput:
        # Delay to ensure the agent has time to produce output before the guardrail finishes.
        await asyncio.sleep(0.01)
        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("should_not_be_saved")])
    tripwire_guardrail = InputGuardrail(guardrail_function=guardrail_function)
    guarded_agent = Agent(name="test", model=fake_model, input_guardrails=[tripwire_guardrail])
    session = SimpleListSession()

    # The run itself must fail with the tripwire exception.
    with pytest.raises(InputGuardrailTripwireTriggered):
        await Runner.run(guarded_agent, input="user_message", session=session)

    # Only one item may remain in the session, and it must carry the user role.
    stored_items = await session.get_items()
    assert len(stored_items) == 1

    sole_item = cast(dict[str, Any], stored_items[0])
    assert "role" in sole_item
    assert sole_item["role"] == "user"
579+
580+
545581
@pytest.mark.asyncio
546582
async def test_output_guardrail_tripwire_triggered_causes_exception():
547583
def guardrail_function(

tests/test_agent_runner_streamed.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import json
5-
from typing import Any
5+
from typing import Any, cast
66

77
import pytest
88
from typing_extensions import TypedDict
@@ -35,6 +35,7 @@
3535
get_text_input_item,
3636
get_text_message,
3737
)
38+
from .utils.simple_session import SimpleListSession
3839

3940

4041
@pytest.mark.asyncio
@@ -524,6 +525,38 @@ def guardrail_function(
524525
pass
525526

526527

528+
@pytest.mark.asyncio
async def test_input_guardrail_streamed_does_not_save_assistant_message_to_session():
    """Streamed variant: a tripped input guardrail leaves one "user" item in the session."""

    async def guardrail_function(
        context: RunContextWrapper[Any], agent: Agent[Any], input: Any
    ) -> GuardrailFunctionOutput:
        # Short delay before the guardrail reports that its tripwire fired.
        await asyncio.sleep(0.01)
        return GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )

    fake_model = FakeModel()
    fake_model.set_next_output([get_text_message("should_not_be_saved")])
    tripwire_guardrail = InputGuardrail(guardrail_function=guardrail_function)
    guarded_agent = Agent(name="test", model=fake_model, input_guardrails=[tripwire_guardrail])
    session = SimpleListSession()

    # Draining the event stream is expected to raise the tripwire exception.
    with pytest.raises(InputGuardrailTripwireTriggered):
        result = Runner.run_streamed(guarded_agent, input="user_message", session=session)
        async for _ in result.stream_events():
            pass

    # Only one item may remain in the session, and it must carry the user role.
    stored_items = await session.get_items()
    assert len(stored_items) == 1

    sole_item = cast(dict[str, Any], stored_items[0])
    assert "role" in sole_item
    assert sole_item["role"] == "user"
558+
559+
527560
@pytest.mark.asyncio
528561
async def test_slow_input_guardrail_still_raises_exception_streamed():
529562
async def guardrail_function(

tests/utils/simple_session.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
from agents.items import TResponseInputItem
4+
from agents.memory.session import Session
5+
6+
7+
class SimpleListSession(Session):
    """In-memory, list-backed session implementation intended for tests."""

    def __init__(self, session_id: str = "test") -> None:
        """Create an empty session identified by ``session_id``."""
        self.session_id = session_id
        self._items: list[TResponseInputItem] = []

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return stored items in insertion order.

        Args:
            limit: ``None`` returns a copy of every item; a positive value
                returns only the last ``limit`` items; zero or a negative
                value returns an empty list.
        """
        if limit is None:
            return self._items.copy()
        return self._items[-limit:] if limit > 0 else []

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to the end of the stored list."""
        self._items += items

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recently added item, or ``None`` when empty."""
        return self._items.pop() if self._items else None

    async def clear_session(self) -> None:
        """Remove every stored item, keeping the same underlying list object."""
        del self._items[:]

0 commit comments

Comments
 (0)