@@ -6,6 +6,7 @@
 import json
 import logging
 import mimetypes
+import re
 import time
 import uuid
 import warnings
@@ -74,7 +75,9 @@
     ToolMessage,
     is_data_content_block,
 )
-from langchain_core.messages.ai import UsageMetadata, add_usage, subtract_usage
+from langchain_core.messages.ai import (
+    UsageMetadata,
+)
 from langchain_core.messages.tool import invalid_tool_call, tool_call, tool_call_chunk
 from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
 from langchain_core.output_parsers.base import OutputParserLike
@@ -798,6 +801,234 @@ def _extract_grounding_metadata(candidate: Any) -> Dict[str, Any]:
     return result
 
 
+def _sanitize_token_detail_key(raw_key: str) -> str:
+    """Convert provider detail labels into snake_case keys."""
+    sanitized = re.sub(r"[^0-9a-zA-Z]+", "_", raw_key.strip().lower()).strip("_")
+    return sanitized or "unknown"
+
+
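A quick illustration of the key normalization above (a minimal sketch with made-up labels, not taken from a real API response):

    # Runs of non-alphanumeric characters collapse to single underscores,
    # then leading/trailing underscores are trimmed.
    assert _sanitize_token_detail_key("Cache Read (Text)") == "cache_read_text"
    # Labels that sanitize down to nothing fall back to "unknown".
    assert _sanitize_token_detail_key("***") == "unknown"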
+def _extract_token_detail_counts(
+    entries: Sequence[Mapping[str, Any]] | None,
+    *,
+    prefix: str | None = None,
+) -> dict[str, int]:
+    """Convert modality/token entries into a token detail mapping."""
+    if not entries:
+        return {}
+    detail_counts: dict[str, int] = {}
+    for entry in entries:
+        if not isinstance(entry, Mapping):
+            continue
+        raw_key = entry.get("modality") or entry.get("type") or entry.get("name")
+        if not raw_key:
+            continue
+        raw_value = (
+            entry.get("token_count")
+            or entry.get("tokenCount")
+            or entry.get("tokens_count")
+            or entry.get("tokensCount")
+            or entry.get("count")
+        )
+        try:
+            value_int = int(raw_value or 0)
+        except (TypeError, ValueError):
+            value_int = 0
+        if value_int == 0:
+            continue
+        key = _sanitize_token_detail_key(str(raw_key))
+        if prefix:
+            key = f"{prefix}{key}"
+        detail_counts[key] = detail_counts.get(key, 0) + value_int
+    return detail_counts
+
+
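A minimal sketch of how the helper above folds modality entries into a detail mapping; the entries are illustrative, using both the snake_case and camelCase count keys the function probes for:

    entries = [
        {"modality": "TEXT", "token_count": 10},
        {"modality": "AUDIO", "tokenCount": 5},
        {"modality": "TEXT", "token_count": 2},   # repeated modalities accumulate
        {"modality": "IMAGE", "token_count": 0},  # zero counts are dropped
    ]
    assert _extract_token_detail_counts(entries) == {"text": 12, "audio": 5}
    # With a prefix, keys are namespaced, e.g. for cache detail buckets:
    assert _extract_token_detail_counts(entries, prefix="cache_") == {
        "cache_text": 12,
        "cache_audio": 5,
    }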
+def _merge_detail_counts(target: dict[str, int], new_entries: dict[str, int]) -> None:
+    """Accumulate modality detail counts into the provided target mapping."""
+    for key, value in new_entries.items():
+        target[key] = target.get(key, 0) + value
+
+
+def _usage_proto_to_dict(raw_usage: Any) -> dict[str, Any]:
+    """Coerce proto UsageMetadata (or dict) into a plain dictionary."""
+    if raw_usage is None:
+        return {}
+    if isinstance(raw_usage, Mapping):
+        return dict(raw_usage)
+    try:
+        return proto.Message.to_dict(raw_usage)
+    except Exception:  # pragma: no cover - best effort fallback
+        try:
+            return dict(raw_usage)
+        except Exception:  # pragma: no cover - final fallback
+            return {}
+
+
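Two behavioral notes on _usage_proto_to_dict, as a minimal sketch (the mapping literal is illustrative): None and plain mappings short-circuit before any proto handling, and only other objects are handed to proto.Message.to_dict, with dict() as the final fallback:

    assert _usage_proto_to_dict(None) == {}
    assert _usage_proto_to_dict({"prompt_token_count": 7}) == {"prompt_token_count": 7}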
+def _coerce_usage_metadata(raw_usage: Any) -> Optional[UsageMetadata]:
+    """Normalize Gemini usage metadata into LangChain's UsageMetadata."""
+    usage_dict = _usage_proto_to_dict(raw_usage)
+    if not usage_dict:
+        return None
+
+    def _get_int(name: str) -> int:
+        value = usage_dict.get(name)
+        try:
+            return int(value or 0)
+        except (TypeError, ValueError):
+            return 0
+
+    prompt_tokens = _get_int("prompt_token_count")
+    response_tokens = (
+        _get_int("candidates_token_count")
+        or _get_int("response_token_count")
+        or _get_int("output_token_count")
+    )
+    tool_prompt_tokens = _get_int("tool_use_prompt_token_count")
+    reasoning_tokens = _get_int("thoughts_token_count") or _get_int(
+        "reasoning_token_count"
+    )
+    cache_read_tokens = _get_int("cached_content_token_count")
+
+    input_tokens = prompt_tokens + tool_prompt_tokens
+    output_tokens = response_tokens + reasoning_tokens
+    total_tokens = _get_int("total_token_count") or _get_int("total_tokens")
+    if total_tokens == 0:
+        total_tokens = input_tokens + output_tokens
+    if total_tokens != input_tokens + output_tokens:
+        total_tokens = input_tokens + output_tokens
+
+    input_details: dict[str, int] = {}
+    if cache_read_tokens:
+        input_details["cache_read"] = cache_read_tokens
+    if tool_prompt_tokens:
+        input_details["tool_use_prompt"] = tool_prompt_tokens
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("prompt_tokens_details")
+            or usage_dict.get("promptTokensDetails"),
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("tool_use_prompt_tokens_details")
+            or usage_dict.get("toolUsePromptTokensDetails"),
+            prefix="tool_use_prompt_",
+        ),
+    )
+    _merge_detail_counts(
+        input_details,
+        _extract_token_detail_counts(
+            usage_dict.get("cache_tokens_details")
+            or usage_dict.get("cacheTokensDetails"),
+            prefix="cache_",
+        ),
+    )
+
+    output_details: dict[str, int] = {}
+    if reasoning_tokens:
+        output_details["reasoning"] = reasoning_tokens
+    for key in (
+        "candidates_tokens_details",
+        "candidatesTokensDetails",
+        "response_tokens_details",
+        "responseTokensDetails",
+        "output_tokens_details",
+        "outputTokensDetails",
+        "total_tokens_details",
+        "totalTokensDetails",
+    ):
+        _merge_detail_counts(
+            output_details, _extract_token_detail_counts(usage_dict.get(key))
+        )
+
+    for alt_key in ("thought", "thoughts", "reasoning_tokens"):
+        if alt_key in output_details:
+            output_details["reasoning"] = output_details.get(
+                "reasoning", 0
+            ) + output_details.pop(alt_key)
+
+    usage_payload: dict[str, Any] = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    if input_details:
+        usage_payload["input_token_details"] = cast(Any, input_details)
+    if output_details:
+        usage_payload["output_token_details"] = cast(Any, output_details)
+
+    return cast(UsageMetadata, usage_payload)
+
+
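To make the normalization concrete, a sketch with an illustrative usage dict (field names follow the Gemini usage metadata; the numbers are made up). Thought tokens count toward output_tokens, and cached prompt tokens surface under input_token_details:

    raw = {
        "prompt_token_count": 8,
        "candidates_token_count": 5,
        "thoughts_token_count": 2,
        "cached_content_token_count": 3,
        "total_token_count": 15,
    }
    assert _coerce_usage_metadata(raw) == {
        "input_tokens": 8,
        "output_tokens": 7,  # 5 candidate tokens + 2 reasoning tokens
        "total_tokens": 15,
        "input_token_details": {"cache_read": 3},
        "output_token_details": {"reasoning": 2},
    }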
+def _diff_token_details(
+    current: Mapping[str, Any] | None,
+    previous: Mapping[str, Any] | None,
+) -> dict[str, int]:
+    """Compute detail deltas between cumulative usage payloads."""
+    if not current and not previous:
+        return {}
+    current = current or {}
+    previous = previous or {}
+    diff: dict[str, int] = {}
+    for key in set(current).union(previous):
+        current_value = current.get(key, 0)
+        previous_value = previous.get(key, 0)
+        if isinstance(current_value, Mapping) or isinstance(previous_value, Mapping):
+            nested = _diff_token_details(
+                current_value if isinstance(current_value, Mapping) else None,
+                previous_value if isinstance(previous_value, Mapping) else None,
+            )
+            if nested:
+                diff[key] = nested  # type: ignore[assignment]
+            continue
+        try:
+            current_int = int(current_value or 0)
+        except (TypeError, ValueError):
+            current_int = 0
+        try:
+            previous_int = int(previous_value or 0)
+        except (TypeError, ValueError):
+            previous_int = 0
+        delta = current_int - previous_int
+        if delta != 0:
+            diff[key] = delta
+    return diff
+
+
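A small sketch of the delta semantics (made-up numbers): keys with unchanged counts drop out of the diff, and keys present on only one side are diffed against zero, so deltas can be negative:

    current = {"cache_read": 30, "text": 12}
    previous = {"cache_read": 30, "text": 7, "audio": 4}
    assert _diff_token_details(current, previous) == {"text": 5, "audio": -4}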
+def _diff_usage_metadata(
+    current: UsageMetadata, previous: UsageMetadata
+) -> UsageMetadata:
+    """Return chunk-level usage delta between cumulative UsageMetadata values."""
+
+    input_delta = current.get("input_tokens", 0) - previous.get("input_tokens", 0)
+    output_delta = current.get("output_tokens", 0) - previous.get("output_tokens", 0)
+    total_delta = current.get("total_tokens", 0) - previous.get("total_tokens", 0)
+    expected_total = input_delta + output_delta
+    if total_delta != expected_total:
+        total_delta = expected_total
+
+    diff_payload: dict[str, Any] = {
+        "input_tokens": input_delta,
+        "output_tokens": output_delta,
+        "total_tokens": total_delta,
+    }
+
+    input_detail_delta = _diff_token_details(
+        current.get("input_token_details"), previous.get("input_token_details")
+    )
+    if input_detail_delta:
+        diff_payload["input_token_details"] = cast(Any, input_detail_delta)
+
+    output_detail_delta = _diff_token_details(
+        current.get("output_token_details"), previous.get("output_token_details")
+    )
+    if output_detail_delta:
+        diff_payload["output_token_details"] = cast(Any, output_detail_delta)
+
+    return cast(UsageMetadata, diff_payload)
+
+
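Putting the two diff helpers together, a sketch of deriving one chunk's usage from cumulative snapshots (illustrative values; plain dicts stand in for the UsageMetadata TypedDict):

    cumulative = {
        "input_tokens": 8,
        "output_tokens": 7,
        "total_tokens": 15,
        "output_token_details": {"reasoning": 2},
    }
    previous = {"input_tokens": 8, "output_tokens": 3, "total_tokens": 11}
    assert _diff_usage_metadata(cumulative, previous) == {
        "input_tokens": 0,
        "output_tokens": 4,
        "total_tokens": 4,
        "output_token_details": {"reasoning": 2},
    }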
 def _response_to_result(
     response: GenerateContentResponse,
     stream: bool = False,
@@ -806,47 +1037,16 @@ def _response_to_result(
8061037 """Converts a PaLM API response into a LangChain ChatResult."""
8071038 llm_output = {"prompt_feedback" : proto .Message .to_dict (response .prompt_feedback )}
8081039
809- # Get usage metadata
810- try :
811- input_tokens = response .usage_metadata .prompt_token_count
812- thought_tokens = response .usage_metadata .thoughts_token_count
813- output_tokens = response .usage_metadata .candidates_token_count + thought_tokens
814- total_tokens = response .usage_metadata .total_token_count
815- cache_read_tokens = response .usage_metadata .cached_content_token_count
816- if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0 :
817- if thought_tokens > 0 :
818- cumulative_usage = UsageMetadata (
819- input_tokens = input_tokens ,
820- output_tokens = output_tokens ,
821- total_tokens = total_tokens ,
822- input_token_details = {"cache_read" : cache_read_tokens },
823- output_token_details = {"reasoning" : thought_tokens },
824- )
825- else :
826- cumulative_usage = UsageMetadata (
827- input_tokens = input_tokens ,
828- output_tokens = output_tokens ,
829- total_tokens = total_tokens ,
830- input_token_details = {"cache_read" : cache_read_tokens },
831- )
832- # previous usage metadata needs to be subtracted because gemini api returns
833- # already-accumulated token counts with each chunk
834- lc_usage = subtract_usage (cumulative_usage , prev_usage )
835- if prev_usage and cumulative_usage ["input_tokens" ] < prev_usage .get (
836- "input_tokens" , 0
837- ):
838- # Gemini 1.5 and 2.0 return a lower cumulative count of prompt tokens
839- # in the final chunk. We take this count to be ground truth because
840- # it's consistent with the reported total tokens. So we need to
841- # ensure this chunk compensates (the subtract_usage funcction floors
842- # at zero).
843- lc_usage ["input_tokens" ] = cumulative_usage [
844- "input_tokens"
845- ] - prev_usage .get ("input_tokens" , 0 )
846- else :
847- lc_usage = None
848- except AttributeError :
849- lc_usage = None
+    cumulative_usage = _coerce_usage_metadata(response.usage_metadata)
+    if cumulative_usage:
+        llm_output["usage_metadata"] = cumulative_usage
+
+    if stream and cumulative_usage and prev_usage:
+        lc_usage: Optional[UsageMetadata] = _diff_usage_metadata(
+            cumulative_usage, prev_usage
+        )
+    else:
+        lc_usage = cumulative_usage
 
     generations: List[ChatGeneration] = []
 
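With this change, the first streamed chunk (no prev_usage yet) carries the full cumulative usage, while every later chunk carries only the delta against the previous snapshot, so summing per-chunk usage_metadata reproduces the final cumulative counts: with made-up numbers, cumulative totals of 10, 14, and 15 across three chunks surface as per-chunk usage of 10, 4, and 1. The cumulative snapshot is also stored in llm_output["usage_metadata"], which the streaming loops below read back to use as the next prev_usage.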
@@ -1961,19 +2161,18 @@ def _stream(
             metadata=self.default_metadata,
         )
 
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         for chunk in response:
             _chat_result = _response_to_result(
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 run_manager.on_llm_new_token(gen.text, chunk=gen)
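Downstream aggregation relies on those per-chunk deltas: adding AIMessageChunk objects together merges their usage_metadata, so a consumer recovers the cumulative totals at the end of the stream. A minimal consumer sketch, assuming an already-configured ChatGoogleGenerativeAI instance named llm (hypothetical here):

    from langchain_core.messages import AIMessageChunk

    final: AIMessageChunk | None = None
    for chunk in llm.stream("Briefly explain token accounting."):
        # Chunk addition sums the per-chunk usage_metadata deltas.
        final = chunk if final is None else final + chunk
    if final is not None:
        print(final.usage_metadata)  # cumulative input/output/total token counts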
@@ -2024,7 +2223,7 @@ async def _astream(
             kwargs["timeout"] = self.timeout
         if "max_retries" not in kwargs:
             kwargs["max_retries"] = self.max_retries
-        prev_usage_metadata: UsageMetadata | None = None  # cumulative usage
+        prev_usage_metadata: UsageMetadata | None = None
         async for chunk in await _achat_with_retry(
             request=request,
             generation_method=self.async_client.stream_generate_content,
@@ -2035,13 +2234,12 @@ async def _astream(
                 chunk, stream=True, prev_usage=prev_usage_metadata
             )
             gen = cast("ChatGenerationChunk", _chat_result.generations[0])
-            message = cast("AIMessageChunk", gen.message)
-
-            prev_usage_metadata = (
-                message.usage_metadata
-                if prev_usage_metadata is None
-                else add_usage(prev_usage_metadata, message.usage_metadata)
+            llm_output = _chat_result.llm_output or {}
+            cumulative_usage = cast(
+                Optional[UsageMetadata], llm_output.get("usage_metadata")
             )
+            if cumulative_usage is not None:
+                prev_usage_metadata = cumulative_usage
 
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text, chunk=gen)