@@ -5884,3 +5884,119 @@ async def test_user_message_then_before_run_same_trace_no_ambient(
58845884 )
58855885
58865886 provider .shutdown ()
5887+
5888+
5889+ class TestRootAgentNameAcrossInvocations :
5890+ """Regression: root_agent_name must refresh across invocations."""
5891+
5892+ @pytest .mark .asyncio
5893+ async def test_root_agent_name_updates_between_invocations (
5894+ self ,
5895+ bq_plugin_inst ,
5896+ mock_write_client ,
5897+ mock_session ,
5898+ dummy_arrow_schema ,
5899+ ):
5900+ """Two invocations with different root agents must log correct names.
5901+
5902+ Previously init_trace() only set _root_agent_name_ctx when it was
5903+ None, so the second invocation would inherit the first's root agent.
5904+ """
5905+ from opentelemetry .sdk .trace import TracerProvider as SdkProvider
5906+ from opentelemetry .sdk .trace .export import SimpleSpanProcessor
5907+ from opentelemetry .sdk .trace .export .in_memory_span_exporter import InMemorySpanExporter
5908+
5909+ provider = SdkProvider ()
5910+ provider .add_span_processor (SimpleSpanProcessor (InMemorySpanExporter ()))
5911+ real_tracer = provider .get_tracer ("test" )
5912+
5913+ mock_session_service = mock .create_autospec (
5914+ base_session_service_lib .BaseSessionService ,
5915+ instance = True ,
5916+ spec_set = True ,
5917+ )
5918+ mock_plugin_manager = mock .create_autospec (
5919+ plugin_manager_lib .PluginManager ,
5920+ instance = True ,
5921+ spec_set = True ,
5922+ )
5923+
5924+ def _make_inv_ctx (agent_name , inv_id ):
5925+ agent = mock .create_autospec (
5926+ base_agent .BaseAgent , instance = True , spec_set = True
5927+ )
5928+ type(agent ).name = mock .PropertyMock (return_value = agent_name )
5929+ type(agent ).instruction = mock .PropertyMock (return_value = "" )
5930+ # root_agent returns itself (no parent).
5931+ agent .root_agent = agent
5932+ return invocation_context_lib .InvocationContext (
5933+ agent = agent ,
5934+ session = mock_session ,
5935+ invocation_id = inv_id ,
5936+ session_service = mock_session_service ,
5937+ plugin_manager = mock_plugin_manager ,
5938+ )
5939+
5940+ with mock .patch .object (
5941+ bigquery_agent_analytics_plugin , "tracer" , real_tracer
5942+ ):
5943+ # --- Invocation 1: root agent = "RootA" ---
5944+ bigquery_agent_analytics_plugin ._span_records_ctx .set (None )
5945+ bigquery_agent_analytics_plugin ._active_invocation_id_ctx .set (None )
5946+ bigquery_agent_analytics_plugin ._root_agent_name_ctx .set (None )
5947+
5948+ inv1 = _make_inv_ctx ("RootA" , "inv-001" )
5949+ cb1 = callback_context_lib .CallbackContext (inv1 )
5950+ await bq_plugin_inst .before_run_callback (invocation_context = inv1 )
5951+ await bq_plugin_inst .before_agent_callback (
5952+ agent = inv1 .agent , callback_context = cb1
5953+ )
5954+ await bq_plugin_inst .after_agent_callback (
5955+ agent = inv1 .agent , callback_context = cb1
5956+ )
5957+ await bq_plugin_inst .after_run_callback (invocation_context = inv1 )
5958+ await asyncio .sleep (0.01 )
5959+
5960+ rows_inv1 = await _get_captured_rows_async (
5961+ mock_write_client , dummy_arrow_schema
5962+ )
5963+
5964+ # --- Invocation 2: root agent = "RootB" ---
5965+ mock_write_client .append_rows .reset_mock ()
5966+
5967+ inv2 = _make_inv_ctx ("RootB" , "inv-002" )
5968+ cb2 = callback_context_lib .CallbackContext (inv2 )
5969+ await bq_plugin_inst .before_run_callback (invocation_context = inv2 )
5970+ await bq_plugin_inst .before_agent_callback (
5971+ agent = inv2 .agent , callback_context = cb2
5972+ )
5973+ await bq_plugin_inst .after_agent_callback (
5974+ agent = inv2 .agent , callback_context = cb2
5975+ )
5976+ await bq_plugin_inst .after_run_callback (invocation_context = inv2 )
5977+ await asyncio .sleep (0.01 )
5978+
5979+ rows_inv2 = await _get_captured_rows_async (
5980+ mock_write_client , dummy_arrow_schema
5981+ )
5982+
5983+ # Parse root_agent_name from the attributes JSON column.
5984+ def _get_root_names (rows ):
5985+ names = set ()
5986+ for r in rows :
5987+ attrs = r .get ("attributes" )
5988+ if attrs :
5989+ parsed = json .loads (attrs ) if isinstance (attrs , str ) else attrs
5990+ if "root_agent_name" in parsed :
5991+ names .add (parsed ["root_agent_name" ])
5992+ return names
5993+
5994+ names_inv1 = _get_root_names (rows_inv1 )
5995+ names_inv2 = _get_root_names (rows_inv2 )
5996+
5997+ # Invocation 1 should only have "RootA".
5998+ assert names_inv1 == {"RootA" }, f"Expected {{'RootA'}}, got { names_inv1 } "
5999+ # Invocation 2 must have "RootB", NOT stale "RootA".
6000+ assert names_inv2 == {"RootB" }, f"Expected {{'RootB'}}, got { names_inv2 } "
6001+
6002+ provider .shutdown ()
0 commit comments