Skip to content

Commit b652f81

Browse files
caohy1988claude
andcommitted
fix: refresh root_agent_name on each invocation, not just first
init_trace() previously only set _root_agent_name_ctx when it was None, so the second invocation with a different root agent would inherit the first's name. Now it sets unconditionally. after_run_callback also resets _root_agent_name_ctx alongside the other invocation cleanup. Also adds a NOTE comment acknowledging that trace contextvars are module-global (not plugin-instance-scoped) as a known limitation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e228b35 commit b652f81

File tree

2 files changed

+129
-6
lines changed

2 files changed

+129
-6
lines changed

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,11 @@ class BigQueryLoggerConfig:
507507
# ==============================================================================
508508
# HELPER: TRACE MANAGER (Async-Safe with ContextVars)
509509
# ==============================================================================
510+
# NOTE: These contextvars are module-global, not plugin-instance-scoped.
511+
# Multiple BigQueryAgentAnalyticsPlugin instances in the same execution
512+
# context will share trace state. This is acceptable for the expected
513+
# single-plugin-per-process deployment, but should be revisited if
514+
# multi-instance support is needed (e.g. scope by plugin instance ID).
510515

511516
_root_agent_name_ctx = contextvars.ContextVar(
512517
"_bq_analytics_root_agent_name", default=None
@@ -564,12 +569,13 @@ def _get_records() -> list[_SpanRecord]:
564569

565570
@staticmethod
566571
def init_trace(callback_context: CallbackContext) -> None:
567-
if _root_agent_name_ctx.get() is None:
568-
try:
569-
root_agent = callback_context._invocation_context.agent.root_agent
570-
_root_agent_name_ctx.set(root_agent.name)
571-
except (AttributeError, ValueError):
572-
pass
572+
# Always refresh root_agent_name — it can change between
573+
# invocations (e.g. different root agents in the same task).
574+
try:
575+
root_agent = callback_context._invocation_context.agent.root_agent
576+
_root_agent_name_ctx.set(root_agent.name)
577+
except (AttributeError, ValueError):
578+
pass
573579

574580
# Ensure records stack is initialized
575581
TraceManager._get_records()
@@ -2734,6 +2740,7 @@ async def after_run_callback(
27342740
# invocation to prevent leaks into the next one.
27352741
TraceManager.clear_stack()
27362742
_active_invocation_id_ctx.set(None)
2743+
_root_agent_name_ctx.set(None)
27372744
# Ensure all logs are flushed before the agent returns
27382745
await self.flush()
27392746

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5884,3 +5884,119 @@ async def test_user_message_then_before_run_same_trace_no_ambient(
58845884
)
58855885

58865886
provider.shutdown()
5887+
5888+
5889+
class TestRootAgentNameAcrossInvocations:
5890+
"""Regression: root_agent_name must refresh across invocations."""
5891+
5892+
@pytest.mark.asyncio
5893+
async def test_root_agent_name_updates_between_invocations(
5894+
self,
5895+
bq_plugin_inst,
5896+
mock_write_client,
5897+
mock_session,
5898+
dummy_arrow_schema,
5899+
):
5900+
"""Two invocations with different root agents must log correct names.
5901+
5902+
Previously init_trace() only set _root_agent_name_ctx when it was
5903+
None, so the second invocation would inherit the first's root agent.
5904+
"""
5905+
from opentelemetry.sdk.trace import TracerProvider as SdkProvider
5906+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
5907+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
5908+
5909+
provider = SdkProvider()
5910+
provider.add_span_processor(SimpleSpanProcessor(InMemorySpanExporter()))
5911+
real_tracer = provider.get_tracer("test")
5912+
5913+
mock_session_service = mock.create_autospec(
5914+
base_session_service_lib.BaseSessionService,
5915+
instance=True,
5916+
spec_set=True,
5917+
)
5918+
mock_plugin_manager = mock.create_autospec(
5919+
plugin_manager_lib.PluginManager,
5920+
instance=True,
5921+
spec_set=True,
5922+
)
5923+
5924+
def _make_inv_ctx(agent_name, inv_id):
5925+
agent = mock.create_autospec(
5926+
base_agent.BaseAgent, instance=True, spec_set=True
5927+
)
5928+
type(agent).name = mock.PropertyMock(return_value=agent_name)
5929+
type(agent).instruction = mock.PropertyMock(return_value="")
5930+
# root_agent returns itself (no parent).
5931+
agent.root_agent = agent
5932+
return invocation_context_lib.InvocationContext(
5933+
agent=agent,
5934+
session=mock_session,
5935+
invocation_id=inv_id,
5936+
session_service=mock_session_service,
5937+
plugin_manager=mock_plugin_manager,
5938+
)
5939+
5940+
with mock.patch.object(
5941+
bigquery_agent_analytics_plugin, "tracer", real_tracer
5942+
):
5943+
# --- Invocation 1: root agent = "RootA" ---
5944+
bigquery_agent_analytics_plugin._span_records_ctx.set(None)
5945+
bigquery_agent_analytics_plugin._active_invocation_id_ctx.set(None)
5946+
bigquery_agent_analytics_plugin._root_agent_name_ctx.set(None)
5947+
5948+
inv1 = _make_inv_ctx("RootA", "inv-001")
5949+
cb1 = callback_context_lib.CallbackContext(inv1)
5950+
await bq_plugin_inst.before_run_callback(invocation_context=inv1)
5951+
await bq_plugin_inst.before_agent_callback(
5952+
agent=inv1.agent, callback_context=cb1
5953+
)
5954+
await bq_plugin_inst.after_agent_callback(
5955+
agent=inv1.agent, callback_context=cb1
5956+
)
5957+
await bq_plugin_inst.after_run_callback(invocation_context=inv1)
5958+
await asyncio.sleep(0.01)
5959+
5960+
rows_inv1 = await _get_captured_rows_async(
5961+
mock_write_client, dummy_arrow_schema
5962+
)
5963+
5964+
# --- Invocation 2: root agent = "RootB" ---
5965+
mock_write_client.append_rows.reset_mock()
5966+
5967+
inv2 = _make_inv_ctx("RootB", "inv-002")
5968+
cb2 = callback_context_lib.CallbackContext(inv2)
5969+
await bq_plugin_inst.before_run_callback(invocation_context=inv2)
5970+
await bq_plugin_inst.before_agent_callback(
5971+
agent=inv2.agent, callback_context=cb2
5972+
)
5973+
await bq_plugin_inst.after_agent_callback(
5974+
agent=inv2.agent, callback_context=cb2
5975+
)
5976+
await bq_plugin_inst.after_run_callback(invocation_context=inv2)
5977+
await asyncio.sleep(0.01)
5978+
5979+
rows_inv2 = await _get_captured_rows_async(
5980+
mock_write_client, dummy_arrow_schema
5981+
)
5982+
5983+
# Parse root_agent_name from the attributes JSON column.
5984+
def _get_root_names(rows):
5985+
names = set()
5986+
for r in rows:
5987+
attrs = r.get("attributes")
5988+
if attrs:
5989+
parsed = json.loads(attrs) if isinstance(attrs, str) else attrs
5990+
if "root_agent_name" in parsed:
5991+
names.add(parsed["root_agent_name"])
5992+
return names
5993+
5994+
names_inv1 = _get_root_names(rows_inv1)
5995+
names_inv2 = _get_root_names(rows_inv2)
5996+
5997+
# Invocation 1 should only have "RootA".
5998+
assert names_inv1 == {"RootA"}, f"Expected {{'RootA'}}, got {names_inv1}"
5999+
# Invocation 2 must have "RootB", NOT stale "RootA".
6000+
assert names_inv2 == {"RootB"}, f"Expected {{'RootB'}}, got {names_inv2}"
6001+
6002+
provider.shutdown()

0 commit comments

Comments
 (0)