1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
@@ -43,7 +43,6 @@
 from tau_bench.types import EnvRunResult
 from tau_bench.types import RunConfig
 import tau_bench_agent as tau_bench_agent_lib
-
 import utils


1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
@@ -25,7 +25,6 @@
 from absl import flags
 import experiment
 from google.genai import types
-
 import utils

_OUTPUT_DIR = flags.DEFINE_string(
75 changes: 52 additions & 23 deletions src/google/adk/models/lite_llm.py
@@ -1023,12 +1023,15 @@ def _model_response_to_chunk(
 if not func_name and not func_args:
   continue

-yield FunctionChunk(
-    id=tool_call.id,
-    name=func_name,
-    args=func_args,
-    index=func_index,
-), finish_reason
+yield (
+    FunctionChunk(
+        id=tool_call.id,
+        name=func_name,
+        args=func_args,
+        index=func_index,
+    ),
+    finish_reason,
+)

 if finish_reason and not (message_content or tool_calls):
   yield None, finish_reason
@@ -1040,12 +1043,17 @@
 # finish_reason set. But this is not the case we are observing from litellm.
 # So we are sending it as a separate chunk to be set on the llm_response.
 if response.get("usage", None):
-  yield UsageMetadataChunk(
-      prompt_tokens=response["usage"].get("prompt_tokens", 0),
-      completion_tokens=response["usage"].get("completion_tokens", 0),
-      total_tokens=response["usage"].get("total_tokens", 0),
-      cached_prompt_tokens=_extract_cached_prompt_tokens(response["usage"]),
-  ), None
+  yield (
+      UsageMetadataChunk(
+          prompt_tokens=response["usage"].get("prompt_tokens", 0),
+          completion_tokens=response["usage"].get("completion_tokens", 0),
+          total_tokens=response["usage"].get("total_tokens", 0),
+          cached_prompt_tokens=_extract_cached_prompt_tokens(
+              response["usage"]
+          ),
+      ),
+      None,
+  )


 def _model_response_to_generate_content_response(
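With both changes above, every item this generator yields is now an explicit (chunk, finish_reason) tuple. A minimal consumption sketch, not part of the diff (handler bodies are placeholders):

for chunk, finish_reason in _model_response_to_chunk(response):
  if isinstance(chunk, FunctionChunk):
    # Accumulate streamed tool-call id/name/args, keyed by chunk.index.
    ...
  elif isinstance(chunk, UsageMetadataChunk):
    # Token counts arrive as their own chunk with finish_reason=None;
    # fold them into the final response's usage metadata.
    ...
  elif chunk is None and finish_reason:
    # Bare finish_reason with no content: close out the current response.
    ...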
@@ -1146,6 +1154,24 @@ def _message_to_generate_content_response(
 )


+def _enforce_closed_schema(schema: dict):
+  if not isinstance(schema, dict):
+    return
+
+  if schema.get("type") == "object":
+    schema.setdefault("additionalProperties", False)
+
+  for prop in schema.get("properties", {}).values():
+    _enforce_closed_schema(prop)
+
+  if "items" in schema:
+    _enforce_closed_schema(schema["items"])
+
+  if "$defs" in schema:
+    for def_schema in schema["$defs"].values():
+      _enforce_closed_schema(def_schema)
+
+
 def _to_litellm_response_format(
     response_schema: types.SchemaUnion,
     model: str,
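A quick sketch of what the new helper does, not part of the diff (the example schema is made up): every dict node typed "object" that does not already declare additionalProperties gets additionalProperties: False, recursively through properties, items, and $defs.

schema = {
    "type": "object",
    "properties": {
        "user": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
        },
        "tags": {"type": "array", "items": {"type": "object"}},
    },
}
_enforce_closed_schema(schema)
# schema["additionalProperties"] is now False, as are
# schema["properties"]["user"]["additionalProperties"] and
# schema["properties"]["tags"]["items"]["additionalProperties"].
# Nodes that already set additionalProperties keep their value (setdefault).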
@@ -1206,14 +1232,9 @@ def _to_litellm_response_format(

 # OpenAI-compatible format (default) per LiteLLM docs:
 # https://docs.litellm.ai/docs/completion/json_mode
-if (
-    isinstance(schema_dict, dict)
-    and schema_dict.get("type") == "object"
-    and "additionalProperties" not in schema_dict
-):
+if isinstance(schema_dict, dict):
   # OpenAI structured outputs require explicit additionalProperties: false.
   schema_dict = dict(schema_dict)
-  schema_dict["additionalProperties"] = False
+  _enforce_closed_schema(schema_dict)

 return {
     "type": "json_schema",
@@ -1433,7 +1454,12 @@ def _warn_gemini_via_litellm(model_string: str) -> None:
 # Check if warning should be suppressed via environment variable
 if os.environ.get(
     "ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS", ""
-).strip().lower() in ("1", "true", "yes", "on"):
+).strip().lower() in (
+    "1",
+    "true",
+    "yes",
+    "on",
+):
   return

 warnings.warn(
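The reformatted check still accepts the same truthy strings, so the warning can be silenced as before, e.g. (sketch using the same variable name as in the diff):

import os

# Any of "1", "true", "yes", "on" (case-insensitive) suppresses the warning.
os.environ["ADK_SUPPRESS_GEMINI_LITELLM_WARNINGS"] = "true"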
@@ -1541,9 +1567,12 @@ async def generate_content_async(
 logger.debug(_build_request_log(llm_request))

 effective_model = llm_request.model or self.model
-messages, tools, response_format, generation_params = (
-    await _get_completion_inputs(llm_request, effective_model)
-)
+(
+    messages,
+    tools,
+    response_format,
+    generation_params,
+) = await _get_completion_inputs(llm_request, effective_model)
 normalized_messages = _normalize_ollama_chat_messages(
     messages,
     model=effective_model,