Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from litellm import Message
from litellm import ModelResponse
from litellm import OpenAIMessageContent
from opentelemetry import trace
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
Expand Down Expand Up @@ -225,6 +226,39 @@ class UsageMetadataChunk(BaseModel):
class LiteLLMClient:
"""Provides acompletion method (for better testability)."""

@staticmethod
def _build_traceparent() -> Optional[str]:
span_context = trace.get_current_span().get_span_context()
if not span_context.is_valid:
return None

trace_id = f"{span_context.trace_id:032x}"
span_id = f"{span_context.span_id:016x}"
trace_flags = f"{int(span_context.trace_flags):02x}"
return f"00-{trace_id}-{span_id}-{trace_flags}"

@classmethod
def _maybe_add_traceparent_header(
cls, extra_headers: Optional[dict[str, str]]
) -> Optional[dict[str, str]]:
traceparent = cls._build_traceparent()
if not traceparent:
return extra_headers

headers_with_trace = dict(extra_headers) if extra_headers else {}
headers_with_trace["traceparent"] = traceparent
return headers_with_trace

@classmethod
def _attach_traceparent_header(cls, kwargs: Dict[str, Any]) -> None:
updated_headers = cls._maybe_add_traceparent_header(
kwargs.get("extra_headers")
)
if updated_headers is None:
kwargs.pop("extra_headers", None)
else:
kwargs["extra_headers"] = updated_headers

async def acompletion(
self, model, messages, tools, **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
Expand All @@ -240,6 +274,8 @@ async def acompletion(
The model response as a message.
"""

self._attach_traceparent_header(kwargs)

return await acompletion(
model=model,
messages=messages,
Expand All @@ -263,6 +299,8 @@ def completion(
The response from the model.
"""

self._attach_traceparent_header(kwargs)

return completion(
model=model,
messages=messages,
Expand Down
137 changes: 137 additions & 0 deletions tests/unittests/models/test_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk
import google.adk.models.lite_llm as lite_llm_module
from google.adk.models.llm_request import LlmRequest
from google.genai import types
import litellm
Expand All @@ -48,6 +49,9 @@
from litellm.types.utils import Delta
from litellm.types.utils import ModelResponse
from litellm.types.utils import StreamingChoices
from opentelemetry.trace import SpanContext
from opentelemetry.trace import TraceFlags
from opentelemetry.trace import TraceState
from pydantic import BaseModel
from pydantic import Field
import pytest
Expand Down Expand Up @@ -210,6 +214,139 @@
]


class _StubSpan:

def __init__(self, span_context):
self._span_context = span_context

def get_span_context(self):
return self._span_context


def _build_valid_span_context():
return SpanContext(
trace_id=int("0123456789abcdef0123456789abcdef", 16),
span_id=int("abcdef0123456789", 16),
is_remote=False,
trace_flags=TraceFlags(1),
trace_state=TraceState(),
)


def _build_invalid_span_context():
return SpanContext(
trace_id=0,
span_id=0,
is_remote=False,
trace_flags=TraceFlags(0),
trace_state=TraceState(),
)


def test_maybe_add_traceparent_header_with_existing_headers(monkeypatch):
span_context = _build_valid_span_context()
monkeypatch.setattr(
lite_llm_module.trace,
"get_current_span",
lambda: _StubSpan(span_context),
)

headers = {"custom": "header"}
result = LiteLLMClient._maybe_add_traceparent_header(headers)

assert result is not headers
assert result["custom"] == "header"
assert result["traceparent"] == (
"00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
)


def test_maybe_add_traceparent_header_without_existing_headers(monkeypatch):
span_context = _build_valid_span_context()
monkeypatch.setattr(
lite_llm_module.trace,
"get_current_span",
lambda: _StubSpan(span_context),
)

result = LiteLLMClient._maybe_add_traceparent_header(None)

assert result == {
"traceparent": "00-0123456789abcdef0123456789abcdef-abcdef0123456789-01"
}


def test_maybe_add_traceparent_header_without_active_span(monkeypatch):
span_context = _build_invalid_span_context()
monkeypatch.setattr(
lite_llm_module.trace,
"get_current_span",
lambda: _StubSpan(span_context),
)

headers = {"custom": "value"}
result = LiteLLMClient._maybe_add_traceparent_header(headers)

assert result is headers


@pytest.mark.asyncio
async def test_litellmclient_acompletion_sets_traceparent_header(monkeypatch):
async_mock = AsyncMock(return_value="response")
monkeypatch.setattr(lite_llm_module, "acompletion", async_mock)

def fake_helper(headers):
assert headers == {"existing": "header"}
return {"existing": "header", "traceparent": "tp"}

monkeypatch.setattr(
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
)

client = LiteLLMClient()
await client.acompletion(
model="test",
messages=[],
tools=None,
extra_headers={"existing": "header"},
custom="value",
)

async_mock.assert_awaited_once()
_, kwargs = async_mock.call_args
assert kwargs["extra_headers"] == {
"existing": "header",
"traceparent": "tp",
}
assert kwargs["custom"] == "value"


def test_litellmclient_completion_sets_traceparent_header(monkeypatch):
sync_mock = Mock(return_value="response")
monkeypatch.setattr(lite_llm_module, "completion", sync_mock)

def fake_helper(headers):
assert headers is None
return {"traceparent": "tp"}

monkeypatch.setattr(
LiteLLMClient, "_maybe_add_traceparent_header", fake_helper
)

client = LiteLLMClient()
client.completion(
model="test",
messages=[],
tools=None,
stream=True,
)

sync_mock.assert_called_once()
_, kwargs = sync_mock.call_args
assert kwargs["extra_headers"] == {"traceparent": "tp"}
assert kwargs["stream"]


class _StructuredOutput(BaseModel):
value: int = Field(description="Value to emit")

Expand Down