Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions src/agents/tracing/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,25 @@ def set_current_span(cls, span: "Span[Any] | None") -> "contextvars.Token[Span[A
return _current_span.set(span)

@classmethod
def reset_current_span(cls, token: "contextvars.Token[Span[Any] | None]") -> None:
_current_span.reset(token)
def reset_current_span(
cls,
token: "contextvars.Token[Span[Any] | None]",
prev_span: "Span[Any] | None" = None,
) -> None:
try:
_current_span.reset(token)
except ValueError:
# Token was created in a different Context. This can happen when multiple
# Runner.run() calls execute concurrently via asyncio.gather().
# Fall back to setting the previous value directly.
# See: https://github.com/openai/openai-agents-python/issues/2246
logger.warning(
"Tracing context mismatch detected during concurrent execution. "
"Span context was reset using fallback. This may affect trace hierarchy "
"in concurrent scenarios. Consider using asyncio.create_task() for concurrent "
"Runner.run() calls to ensure proper context isolation."
)
_current_span.set(prev_span)

@classmethod
def get_current_trace(cls) -> "Trace | None":
Expand All @@ -44,6 +61,29 @@ def set_current_trace(cls, trace: "Trace | None") -> "contextvars.Token[Trace |
return _current_trace.set(trace)

@classmethod
def reset_current_trace(cls, token: "contextvars.Token[Trace | None]") -> None:
def reset_current_trace(
cls,
token: "contextvars.Token[Trace | None]",
prev_trace: "Trace | None" = None,
) -> None:
"""Reset the current trace to its previous value.

Uses token-based reset when possible, with fallback to direct set for
concurrent execution scenarios where Context objects may differ.
See: https://github.com/openai/openai-agents-python/issues/2246
"""
logger.debug("Resetting current trace")
_current_trace.reset(token)
try:
_current_trace.reset(token)
except ValueError:
# Token was created in a different Context. This can happen when multiple
# Runner.run() calls execute concurrently via asyncio.gather().
# Fall back to setting the previous value directly.
# See: https://github.com/openai/openai-agents-python/issues/2246
logger.warning(
"Tracing context mismatch detected during concurrent execution. "
"Trace context was reset using fallback. This may affect trace hierarchy "
"in concurrent scenarios. Consider using asyncio.create_task() for concurrent "
"Runner.run() calls to ensure proper context isolation."
)
_current_trace.set(prev_trace)
13 changes: 10 additions & 3 deletions src/agents/tracing/spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,12 @@ class NoOpSpan(Span[TSpanData]):
span_data: The operation-specific data for this span.
"""

__slots__ = ("_span_data", "_prev_span_token")
__slots__ = ("_span_data", "_prev_span_token", "_prev_span")

def __init__(self, span_data: TSpanData):
self._span_data = span_data
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
self._prev_span: Span[Any] | None = None

@property
def trace_id(self) -> str:
Expand All @@ -206,12 +207,14 @@ def parent_id(self) -> str | None:

def start(self, mark_as_current: bool = False):
if mark_as_current:
self._prev_span = Scope.get_current_span()
self._prev_span_token = Scope.set_current_span(self)

def finish(self, reset_current: bool = False) -> None:
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
Scope.reset_current_span(self._prev_span_token, self._prev_span)
self._prev_span_token = None
self._prev_span = None

def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
Expand Down Expand Up @@ -253,6 +256,7 @@ class SpanImpl(Span[TSpanData]):
"_ended_at",
"_error",
"_prev_span_token",
"_prev_span",
"_processor",
"_span_data",
)
Expand All @@ -273,6 +277,7 @@ def __init__(
self._processor = processor
self._error: SpanError | None = None
self._prev_span_token: contextvars.Token[Span[TSpanData] | None] | None = None
self._prev_span: Span[Any] | None = None
self._span_data = span_data

@property
Expand All @@ -299,6 +304,7 @@ def start(self, mark_as_current: bool = False):
self._started_at = util.time_iso()
self._processor.on_span_start(self)
if mark_as_current:
self._prev_span = Scope.get_current_span()
self._prev_span_token = Scope.set_current_span(self)

def finish(self, reset_current: bool = False) -> None:
Expand All @@ -309,8 +315,9 @@ def finish(self, reset_current: bool = False) -> None:
self._ended_at = util.time_iso()
self._processor.on_span_end(self)
if reset_current and self._prev_span_token is not None:
Scope.reset_current_span(self._prev_span_token)
Scope.reset_current_span(self._prev_span_token, self._prev_span)
self._prev_span_token = None
self._prev_span = None

def __enter__(self) -> Span[TSpanData]:
self.start(mark_as_current=True)
Expand Down
11 changes: 9 additions & 2 deletions src/agents/tracing/traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class NoOpTrace(Trace):
def __init__(self):
self._started = False
self._prev_context_token: contextvars.Token[Trace | None] | None = None
self._prev_trace: Trace | None = None

def __enter__(self) -> Trace:
if self._started:
Expand All @@ -160,12 +161,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def start(self, mark_as_current: bool = False):
if mark_as_current:
self._prev_trace = Scope.get_current_trace()
self._prev_context_token = Scope.set_current_trace(self)

def finish(self, reset_current: bool = False):
if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
Scope.reset_current_trace(self._prev_context_token, self._prev_trace)
self._prev_context_token = None
self._prev_trace = None

@property
def trace_id(self) -> str:
Expand Down Expand Up @@ -208,6 +211,7 @@ class TraceImpl(Trace):
"group_id",
"metadata",
"_prev_context_token",
"_prev_trace",
"_processor",
"_started",
)
Expand All @@ -225,6 +229,7 @@ def __init__(
self.group_id = group_id
self.metadata = metadata
self._prev_context_token: contextvars.Token[Trace | None] | None = None
self._prev_trace: Trace | None = None
self._processor = processor
self._started = False

Expand All @@ -244,6 +249,7 @@ def start(self, mark_as_current: bool = False):
self._processor.on_trace_start(self)

if mark_as_current:
self._prev_trace = Scope.get_current_trace()
self._prev_context_token = Scope.set_current_trace(self)

def finish(self, reset_current: bool = False):
Expand All @@ -253,8 +259,9 @@ def finish(self, reset_current: bool = False):
self._processor.on_trace_end(self)

if reset_current and self._prev_context_token is not None:
Scope.reset_current_trace(self._prev_context_token)
Scope.reset_current_trace(self._prev_context_token, self._prev_trace)
self._prev_context_token = None
self._prev_trace = None

def __enter__(self) -> Trace:
if self._started:
Expand Down