
Commit c643d73

Fix Gemini usage metadata handling
1 parent 0d98d7a commit c643d73

6 files changed: +775, -96 lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 254 additions & 56 deletions
@@ -6,6 +6,7 @@
 import json
 import logging
 import mimetypes
+import re
 import time
 import uuid
 import warnings
@@ -74,7 +75,9 @@
     ToolMessage,
     is_data_content_block,
 )
-from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
+from langchain_core.messages.ai import (
+    UsageMetadata,
+)
 from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
@@ -798,6 +801,234 @@ def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]
     return result
 
 
+def _sanitize_token_detail_key(raw_key: str) -> str:
+    """Convert provider detail labels into snake_case keys."""
+    sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_")
+    return sanitized or "unknown"
+
+
+def _extract_token_detail_counts(
+    entries: Sequence[Mapping[str, Any]] | None,
+    *,
+    prefix: str | None = None,
+) -> dict[str, int]:
+    """Convert modality/token entries into a token detail mapping."""
+    if not entries:
+        return {}
+    detail_counts: dict[str, int] = {}
+    for entry in entries:
+        if not isinstance(entry, Mapping):
+            continue
+        raw_key = entry.get("modality") or entry.get("type") or entry.get("name")
+        if not raw_key:
+            continue
+        raw_value = (
+            entry.get("token_count")
+            or entry.get("tokenCount")
+            or entry.get("tokens_count")
+            or entry.get("tokensCount")
+            or entry.get("count")
+        )
+        try:
+            value_int = int(raw_value or 0)
+        except (TypeError, ValueError):
+            value_int = 0
+        if value_int == 0:
+            continue
+        key = _sanitize_token_detail_key(str(raw_key))
+        if prefix:
+            key = f"{prefix}{key}"
+        detail_counts[key] = detail_counts.get(key, 0) + value_int
+    return detail_counts
+
+
+def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None:
+    """Accumulate modality detail counts into the provided target mapping."""
+    for key, value in new_entries.items():
+        target[key] = target.get(key, 0) + value
+
+
+def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]:
+    """Coerce proto UsageMetadata (or dict) into a plain dictionary."""
+    if raw_usage is None:
+        return {}
+    if isinstance(raw_usage, Mapping):
+        return dict(raw_usage)
+    try:
+        return proto.Message.to_dict(raw_usage)
+    except Exception:  # pragma: no cover - best effort fallback
+        try:
+            return dict(raw_usage)
+        except Exception:  # pragma: no cover - final fallback
+            return {}
+
+
+def _coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]:
+    """Normalize Gemini usage metadata into LangChain's UsageMetadata."""
+    usage_dict = _usage_proto_to_dict(raw_usage)
+    if not usage_dict:
+        return None
+
+    def _get_int(name: str) -> int:
+        value = usage_dict.get(name)
+        try:
+            return int(value or 0)
+        except (TypeError, ValueError):
+            return 0
+
+    prompt_tokens = _get_int("prompt_token_count")
+    response_tokens = (
+        _get_int("candidates_token_count")
+        or _get_int("response_token_count")
+        or _get_int("output_token_count")
+    )
+    tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
+    reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
+        "reasoning_token_count"
+    )
+    cache_read_tokens = _get_int("cached_content_token_count")
+
+    input_tokens = prompt_tokens + tool_prompt_tokens
+    output_tokens = response_tokens + reasoning_tokens
+    total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
+    if total_tokens == 0:
+        total_tokens = input_tokens + output_tokens
+    if total_tokens != input_tokens + output_tokens:
+        total_tokens = input_tokens + output_tokens
+
+    input_details: dict[str, int] = {}
+    if cache_read_tokens:
+        input_details["cache_read"] = cache_read_tokens
+    if tool_prompt_tokens:
+        input_details["tool_use_prompt"] = tool_prompt_tokens
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("prompt_tokens_details")
+            or usage_dict.get("promptTokensDetails"),
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("tool_use_prompt_tokens_details")
+            or usage_dict.get("toolUsePromptTokensDetails"),
+            prefix="tool_use_prompt_",
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("cache_tokens_details")
+            or usage_dict.get("cacheTokensDetails"),
+            prefix="cache_",
+        ),
+    )
+
+    output_details: dict[str, int] = {}
+    if reasoning_tokens:
+        output_details["reasoning"] = reasoning_tokens
+    for key in (
+        "candidates_tokens_details",
+        "candidatesTokensDetails",
+        "response_tokens_details",
+        "responseTokensDetails",
+        "output_tokens_details",
+        "outputTokensDetails",
+        "total_tokens_details",
+        "totalTokensDetails",
+    ):
+        _merge_detail_counts(
+            output_details, _extract_token_detail_counts(usage_dict.get(key))
+        )
+
+    for alt_key in ("thought", "thoughts", "reasoning_tokens"):
+        if alt_key in output_details:
+            output_details["reasoning"] = output_details.get(
+                "reasoning", 0
+            ) + output_details.pop(alt_key)
+
+    usage_payload: dict[str, Any] = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    if input_details:
+        usage_payload["input_token_details"] = cast(Any, input_details)
+    if output_details:
+        usage_payload["output_token_details"] = cast(Any, output_details)
+
+    return cast(UsageMetadata, usage_payload)
+
+
+def _diff_token_details(
+    current: Mapping[str, Any] | None,
+    previous: Mapping[str, Any] | None,
+) -> dict[str, int]:
+    """Compute detail deltas between cumulative usage payloads."""
+    if not current and not previous:
+        return {}
+    current = current or {}
+    previous = previous or {}
+    diff: dict[str, int] = {}
+    for key in set(current).union(previous):
+        current_value = current.get(key, 0)
+        previous_value = previous.get(key, 0)
+        if isinstance(current_value, Mapping) or isinstance(previous_value, Mapping):
+            nested = _diff_token_details(
+                current_value if isinstance(current_value, Mapping) else None,
+                previous_value if isinstance(previous_value, Mapping) else None,
+            )
+            if nested:
+                diff[key] = nested  # type: ignore[assignment]
+            continue
+        try:
+            current_int = int(current_value or 0)
+        except (TypeError, ValueError):
+            current_int = 0
+        try:
+            previous_int = int(previous_value or 0)
+        except (TypeError, ValueError):
+            previous_int = 0
+        delta = current_int - previous_int
+        if delta != 0:
+            diff[key] = delta
+    return diff
+
+
+def _diff_usage_metadata(
+    current: UsageMetadata, previous: UsageMetadata
+) -> UsageMetadata:
+    """Return chunk-level usage delta between cumulative UsageMetadata values."""
+
+    input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
+    output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
+    total_delta = current.get("total_tokens", 0) - previous.get("total_tokens", 0)
+    expected_total = input_delta + output_delta
+    if total_delta != expected_total:
+        total_delta = expected_total
+
+    diff_payload: dict[str, Any] = {
+        "input_tokens": input_delta,
+        "output_tokens": output_delta,
+        "total_tokens": total_delta,
+    }
+
+    input_detail_delta = _diff_token_details(
+        current.get("input_token_details"), previous.get("input_token_details")
+    )
+    if input_detail_delta:
+        diff_payload["input_token_details"] = cast(Any, input_detail_delta)
+
+    output_detail_delta = _diff_token_details(
+        current.get("output_token_details"), previous.get("output_token_details")
+    )
+    if output_detail_delta:
+        diff_payload["output_token_details"] = cast(Any, output_detail_delta)
+
+    return cast(UsageMetadata, diff_payload)
+
+
 def _response_to_result(
     response: GenerateContentResponse,
     stream: bool = False,
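The hunk above introduces the normalization helpers. As a rough illustration of what _coerce_usage_metadata produces, here is a minimal sketch with invented token counts; the private import path is an assumption and only resolves with this commit's version of langchain_google_genai installed:

# Minimal sketch, not part of the commit: invented token counts fed through the
# new helper to show the resulting UsageMetadata shape.
from langchain_google_genai.chat_models import _coerce_usage_metadata

raw_usage = {
    "prompt_token_count": 100,
    "candidates_token_count": 40,
    "thoughts_token_count": 10,
    "cached_content_token_count": 20,
    "total_token_count": 150,
    "prompt_tokens_details": [
        {"modality": "TEXT", "token_count": 80},
        {"modality": "IMAGE", "token_count": 20},
    ],
}

usage = _coerce_usage_metadata(raw_usage)
# usage == {
#     "input_tokens": 100,
#     "output_tokens": 50,  # 40 candidate tokens + 10 reasoning tokens
#     "total_tokens": 150,
#     "input_token_details": {"cache_read": 20, "text": 80, "image": 20},
#     "output_token_details": {"reasoning": 10},
# }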
@@ -806,47 +1037,16 @@ def _response_to_result
     """Converts a PaLM API response into a LangChain ChatResult."""
     llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
 
-    # Get usage metadata
-    try:
-        input_tokens = response.usage_metadata.prompt_token_count
-        thought_tokens = response.usage_metadata.thoughts_token_count
-        output_tokens = response.usage_metadata.candidates_token_count + thought_tokens
-        total_tokens = response.usage_metadata.total_token_count
-        cache_read_tokens = response.usage_metadata.cached_content_token_count
-        if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0:
-            if thought_tokens > 0:
-                cumulative_usage = UsageMetadata(
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                    input_token_details={"cache_read": cache_read_tokens},
-                    output_token_details={"reasoning": thought_tokens},
-                )
-            else:
-                cumulative_usage = UsageMetadata(
-                    input_tokens=input_tokens,
-                    output_tokens=output_tokens,
-                    total_tokens=total_tokens,
-                    input_token_details={"cache_read": cache_read_tokens},
-                )
-            # previous usage metadata needs to be subtracted because gemini api returns
-            # already-accumulated token counts with each chunk
-            lc_usage = subtract_usage(cumulative_usage, prev_usage)
-            if prev_usage and cumulative_usage["input_tokens"] < prev_usage.get(
-                "input_tokens", 0
-            ):
-                # Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens
-                # in the final chunk. We take this count to be ground truth because
-                # it's consistent with the reported total tokens. So we need to
-                # ensure this chunk compensates (the subtract_usage funcction floors
-                # at zero).
-                lc_usage["input_tokens"] = cumulative_usage[
-                    "input_tokens"
-                ] - prev_usage.get("input_tokens", 0)
-        else:
-            lc_usage = None
-    except AttributeError:
-        lc_usage = None
+    cumulative_usage = _coerce_usage_metadata(response.usage_metadata)
+    if cumulative_usage:
+        llm_output["usage_metadata"] = cumulative_usage
+
+    if stream and cumulative_usage and prev_usage:
+        lc_usage: Optional[UsageMetadata] = _diff_usage_metadata(
+            cumulative_usage, prev_usage
+        )
+    else:
+        lc_usage = cumulative_usage
 
     generations: List[ChatGeneration] = []
 
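With this change, _response_to_result records the cumulative usage in llm_output and, when streaming, emits only the per-chunk delta. A small sketch of the delta computation follows; the values are invented, plain dicts stand in for the UsageMetadata TypedDict, and the private import is an assumption tied to this commit:

# Minimal sketch, not part of the commit: two invented cumulative payloads and
# the per-chunk delta _diff_usage_metadata returns for them.
from langchain_google_genai.chat_models import _diff_usage_metadata

previous = {"input_tokens": 100, "output_tokens": 20, "total_tokens": 120}
current = {
    "input_tokens": 100,
    "output_tokens": 50,
    "total_tokens": 150,
    "output_token_details": {"reasoning": 10},
}

delta = _diff_usage_metadata(current, previous)
# delta == {
#     "input_tokens": 0,
#     "output_tokens": 30,
#     "total_tokens": 30,
#     "output_token_details": {"reasoning": 10},
# }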

@@ -1961,19 +2161,18 @@ def _stream(
             metadata=self.default_metadata,
         )
 
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
        for chunk in response:
             _chat_result = _response_to_result(
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 run_manager.on_llm_new_token(gen.text, chunk=gen)
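The _stream loop now keeps the latest cumulative report from llm_output instead of re-adding per-chunk usage. A generic sketch of that cumulative-to-delta bookkeeping is below; the function name and shape are illustrative only, not the library's internals:

# Minimal sketch, not part of the commit: turn cumulative usage reports into
# per-chunk deltas, keeping the latest cumulative snapshot as the loop above does.
from typing import Optional

from langchain_core.messages.ai import UsageMetadata


def stream_deltas(cumulative_reports: list[UsageMetadata]) -> list[UsageMetadata]:
    """Turn a provider's cumulative usage reports into per-chunk deltas."""
    prev: Optional[UsageMetadata] = None
    deltas: list[UsageMetadata] = []
    for report in cumulative_reports:
        if prev is None:
            deltas.append(report)
        else:
            deltas.append(
                UsageMetadata(
                    input_tokens=report["input_tokens"] - prev["input_tokens"],
                    output_tokens=report["output_tokens"] - prev["output_tokens"],
                    total_tokens=report["total_tokens"] - prev["total_tokens"],
                )
            )
        prev = report  # keep the cumulative snapshot for the next chunk
    return deltas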
@@ -2024,7 +2223,7 @@ async def _astream(
             kwargs["timeout"] = self.timeout
         if "max_retries" not in kwargs:
             kwargs["max_retries"] = self.max_retries
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         async for chunk in await _achat_with_retry(
             request=request,
             generation_method=self.async_client.stream_generate_content,
@@ -2035,13 +2234,12 @@
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text, chunk=gen)
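A hypothetical end-to-end check of the behavior this commit targets: with per-chunk deltas, accumulating usage across streamed chunks should match the provider's final totals instead of double counting. The model name and GOOGLE_API_KEY environment setup below are assumptions, not part of the commit:

# Minimal sketch, not part of the commit: sum the per-chunk usage deltas that the
# streaming code above now emits.
from typing import Optional

from langchain_core.messages.ai import UsageMetadata, add_usage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")  # assumed model name

running: Optional[UsageMetadata] = None
for chunk in llm.stream("Explain token counting in one sentence."):
    if chunk.usage_metadata:
        running = (
            chunk.usage_metadata
            if running is None
            else add_usage(running, chunk.usage_metadata)
        )

print(running)  # accumulated input/output/total tokens plus any detail buckets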
