Skip to content

Commit bcbfeba

Browse files
leonzzcopybara-github
authored andcommitted
feat: pass trace context in MCP tool call's _meta field with Otel propagator
PiperOrigin-RevId: 868841079
1 parent 9dccd6a commit bcbfeba

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.genai.types import FunctionDeclaration
3232
from mcp.shared.session import ProgressFnT
3333
from mcp.types import Tool as McpBaseTool
34+
from opentelemetry import propagate
3435
from typing_extensions import override
3536

3637
from ...agents.callback_context import CallbackContext
@@ -313,6 +314,12 @@ async def _run_async_impl(
313314
headers.update(dynamic_headers)
314315
final_headers = headers if headers else None
315316

317+
# Propagate trace context in the _meta field as sprcified by MCP protocol.
318+
# See https://agentclientprotocol.com/protocol/extensibility#the-meta-field
319+
trace_carrier: Dict[str, str] = {}
320+
propagate.get_global_textmap().inject(carrier=trace_carrier)
321+
meta_trace_context = trace_carrier if trace_carrier else None
322+
316323
# Get the session from the session manager
317324
session = await self._mcp_session_manager.create_session(
318325
headers=final_headers
@@ -325,6 +332,7 @@ async def _run_async_impl(
325332
self._mcp_tool.name,
326333
arguments=args,
327334
progress_callback=resolved_callback,
335+
meta=meta_trace_context,
328336
)
329337
return response.model_dump(exclude_none=True, mode="json")
330338

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from google.adk.auth.auth_credential import ServiceAccount
2626
from google.adk.features import FeatureName
2727
from google.adk.features._feature_registry import temporary_feature_override
28+
from google.adk.tools.mcp_tool import mcp_tool
2829
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
2930
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
3031
from google.adk.tools.tool_context import ToolContext
@@ -225,7 +226,7 @@ async def test_run_async_impl_no_auth(self):
225226
)
226227
# Fix: call_tool uses 'arguments' parameter, not positional args
227228
self.mock_session.call_tool.assert_called_once_with(
228-
"test_tool", arguments=args, progress_callback=None
229+
"test_tool", arguments=args, progress_callback=None, meta=None
229230
)
230231

231232
@pytest.mark.asyncio
@@ -262,6 +263,55 @@ async def test_run_async_impl_with_oauth2(self):
262263
headers = call_args[1]["headers"]
263264
assert headers == {"Authorization": "Bearer test_access_token"}
264265

266+
@patch.object(mcp_tool, "propagate", autospec=True)
267+
@pytest.mark.asyncio
268+
async def test_run_async_impl_with_trace_context(self, mock_propagate):
269+
"""Test running tool with trace context injection."""
270+
mock_propagator = Mock()
271+
272+
def inject_context(carrier, context=None) -> None:
273+
carrier["traceparent"] = (
274+
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
275+
)
276+
carrier["tracestate"] = "foo=bar"
277+
carrier["baggage"] = "baz=qux"
278+
279+
mock_propagator.inject.side_effect = inject_context
280+
mock_propagate.get_global_textmap.return_value = mock_propagator
281+
282+
tool = MCPTool(
283+
mcp_tool=self.mock_mcp_tool,
284+
mcp_session_manager=self.mock_session_manager,
285+
)
286+
287+
mcp_response = CallToolResult(
288+
content=[TextContent(type="text", text="success")]
289+
)
290+
self.mock_session.call_tool = AsyncMock(return_value=mcp_response)
291+
292+
tool_context = Mock(spec=ToolContext)
293+
args = {"param1": "test_value"}
294+
295+
await tool._run_async_impl(
296+
args=args, tool_context=tool_context, credential=None
297+
)
298+
299+
self.mock_session_manager.create_session.assert_called_once_with(
300+
headers=None
301+
)
302+
self.mock_session.call_tool.assert_called_once_with(
303+
"test_tool",
304+
arguments=args,
305+
progress_callback=None,
306+
meta={
307+
"traceparent": (
308+
"00-1234567890abcdef1234567890abcdef-1234567890abcdef-01"
309+
),
310+
"tracestate": "foo=bar",
311+
"baggage": "baz=qux",
312+
},
313+
)
314+
265315
@pytest.mark.asyncio
266316
async def test_get_headers_oauth2(self):
267317
"""Test header generation for OAuth2 credentials."""
@@ -778,7 +828,7 @@ async def test_run_async_impl_with_header_provider_no_auth(self):
778828
headers=expected_headers
779829
)
780830
self.mock_session.call_tool.assert_called_once_with(
781-
"test_tool", arguments=args, progress_callback=None
831+
"test_tool", arguments=args, progress_callback=None, meta=None
782832
)
783833

784834
@pytest.mark.asyncio
@@ -821,7 +871,7 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self):
821871
"X-Tenant-ID": "test-tenant",
822872
}
823873
self.mock_session.call_tool.assert_called_once_with(
824-
"test_tool", arguments=args, progress_callback=None
874+
"test_tool", arguments=args, progress_callback=None, meta=None
825875
)
826876

827877
def test_init_with_progress_callback(self):
@@ -875,7 +925,10 @@ async def my_progress_callback(
875925
)
876926
# Verify progress_callback was passed to call_tool
877927
self.mock_session.call_tool.assert_called_once_with(
878-
"test_tool", arguments=args, progress_callback=my_progress_callback
928+
"test_tool",
929+
arguments=args,
930+
progress_callback=my_progress_callback,
931+
meta=None,
879932
)
880933

881934
@pytest.mark.asyncio

0 commit comments

Comments
 (0)