diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 5098425e93..a4f83a15a3 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -49,6 +49,7 @@ from litellm import Message from litellm import ModelResponse from litellm import OpenAIMessageContent +from opentelemetry import trace from pydantic import BaseModel from pydantic import Field from typing_extensions import override @@ -225,6 +226,39 @@ class UsageMetadataChunk(BaseModel): class LiteLLMClient: """Provides acompletion method (for better testability).""" + @staticmethod + def _build_traceparent() -> Optional[str]: + span_context = trace.get_current_span().get_span_context() + if not span_context.is_valid: + return None + + trace_id = f"{span_context.trace_id:032x}" + span_id = f"{span_context.span_id:016x}" + trace_flags = f"{int(span_context.trace_flags):02x}" + return f"00-{trace_id}-{span_id}-{trace_flags}" + + @classmethod + def _maybe_add_traceparent_header( + cls, extra_headers: Optional[dict[str, str]] + ) -> Optional[dict[str, str]]: + traceparent = cls._build_traceparent() + if not traceparent: + return extra_headers + + headers_with_trace = dict(extra_headers) if extra_headers else {} + headers_with_trace["traceparent"] = traceparent + return headers_with_trace + + @classmethod + def _attach_traceparent_header(cls, kwargs: Dict[str, Any]) -> None: + updated_headers = cls._maybe_add_traceparent_header( + kwargs.get("extra_headers") + ) + if updated_headers is None: + kwargs.pop("extra_headers", None) + else: + kwargs["extra_headers"] = updated_headers + async def acompletion( self, model, messages, tools, **kwargs ) -> Union[ModelResponse, CustomStreamWrapper]: @@ -240,6 +274,8 @@ async def acompletion( The model response as a message. """ + self._attach_traceparent_header(kwargs) + return await acompletion( model=model, messages=messages, @@ -263,6 +299,8 @@ def completion( The response from the model. """ + self._attach_traceparent_header(kwargs) + return completion( model=model, messages=messages, diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 094fde774a..5059fa8d8b 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -37,6 +37,7 @@ from google.adk.models.lite_llm import LiteLLMClient from google.adk.models.lite_llm import TextChunk from google.adk.models.lite_llm import UsageMetadataChunk +import google.adk.models.lite_llm as lite_llm_module from google.adk.models.llm_request import LlmRequest from google.genai import types import litellm @@ -48,6 +49,9 @@ from litellm.types.utils import Delta from litellm.types.utils import ModelResponse from litellm.types.utils import StreamingChoices +from opentelemetry.trace import SpanContext +from opentelemetry.trace import TraceFlags +from opentelemetry.trace import TraceState from pydantic import BaseModel from pydantic import Field import pytest @@ -210,6 +214,139 @@ ] +class _StubSpan: + + def __init__(self, span_context): + self._span_context = span_context + + def get_span_context(self): + return self._span_context + + +def _build_valid_span_context(): + return SpanContext( + trace_id=int("0123456789abcdef0123456789abcdef", 16), + span_id=int("abcdef0123456789", 16), + is_remote=False, + trace_flags=TraceFlags(1), + trace_state=TraceState(), + ) + + +def _build_invalid_span_context(): + return SpanContext( + trace_id=0, + span_id=0, + is_remote=False, + trace_flags=TraceFlags(0), + trace_state=TraceState(), + ) + + +def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch): + span_context = _build_valid_span_context() + monkeypatch.setattr( + lite_llm_module.trace, + "get_current_span", + lambda: _StubSpan(span_context), + ) + + headers = {"custom": "header"} + result = LiteLLMClient._maybe_add_traceparent_header(headers) + + assert result is not headers + assert result["custom"] == "header" + assert result["traceparent"] == ( + "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" + ) + + +def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch): + span_context = _build_valid_span_context() + monkeypatch.setattr( + lite_llm_module.trace, + "get_current_span", + lambda: _StubSpan(span_context), + ) + + result = LiteLLMClient._maybe_add_traceparent_header(None) + + assert result == { + "traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01" + } + + +def test_maybe_add_traceparent_header_without_active_span(monkeypatch): + span_context = _build_invalid_span_context() + monkeypatch.setattr( + lite_llm_module.trace, + "get_current_span", + lambda: _StubSpan(span_context), + ) + + headers = {"custom": "value"} + result = LiteLLMClient._maybe_add_traceparent_header(headers) + + assert result is headers + + +@pytest.mark.asyncio +async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch): + async_mock = AsyncMock(return_value="response") + monkeypatch.setattr(lite_llm_module, "acompletion", async_mock) + + def fake_helper(headers): + assert headers == {"existing": "header"} + return {"existing": "header", "traceparent": "tp"} + + monkeypatch.setattr( + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper + ) + + client = LiteLLMClient() + await client.acompletion( + model="test", + messages=[], + tools=None, + extra_headers={"existing": "header"}, + custom="value", + ) + + async_mock.assert_awaited_once() + _, kwargs = async_mock.call_args + assert kwargs["extra_headers"] == { + "existing": "header", + "traceparent": "tp", + } + assert kwargs["custom"] == "value" + + +def test_litellmclient_completion_sets_traceparent_header(monkeypatch): + sync_mock = Mock(return_value="response") + monkeypatch.setattr(lite_llm_module, "completion", sync_mock) + + def fake_helper(headers): + assert headers is None + return {"traceparent": "tp"} + + monkeypatch.setattr( + LiteLLMClient, "_maybe_add_traceparent_header", fake_helper + ) + + client = LiteLLMClient() + client.completion( + model="test", + messages=[], + tools=None, + stream=True, + ) + + sync_mock.assert_called_once() + _, kwargs = sync_mock.call_args + assert kwargs["extra_headers"] == {"traceparent": "tp"} + assert kwargs["stream"] + + class _StructuredOutput(BaseModel): value: int = Field(description="Value to emit")