@@ -1,5 +1,5 @@
 import json
-from typing import Any, List, Literal, Tuple
+from typing import Any, List, Literal, Tuple, Optional
 
 import litellm
 from litellm._logging import verbose_logger
@@ -10,28 +10,30 @@
 async def calculate_batch_cost_and_usage(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
+    model_name: Optional[str] = None,
 ) -> Tuple[float, Usage, List[str]]:
     """
     Calculate the cost and usage of a batch
     """
-    # Calculate costs and usage
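+    # model_name, when provided, overrides per-request model detection and enables Vertex AI costing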
     batch_cost = _batch_cost_calculator(
         custom_llm_provider=custom_llm_provider,
         file_content_dictionary=file_content_dictionary,
+        model_name=model_name,
     )
     batch_usage = _get_batch_job_total_usage_from_file_content(
         file_content_dictionary=file_content_dictionary,
         custom_llm_provider=custom_llm_provider,
+        model_name=model_name,
     )
-
-    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
 
     return batch_cost, batch_usage, batch_models
 
 
 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
+    model_name: Optional[str] = None,
 ) -> Tuple[float, Usage, List[str]]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
@@ -43,23 +45,28 @@ async def _handle_completed_batch(
     batch_cost = _batch_cost_calculator(
         custom_llm_provider=custom_llm_provider,
         file_content_dictionary=file_content_dictionary,
+        model_name=model_name,
     )
     batch_usage = _get_batch_job_total_usage_from_file_content(
         file_content_dictionary=file_content_dictionary,
         custom_llm_provider=custom_llm_provider,
+        model_name=model_name,
     )
 
-    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
 
     return batch_cost, batch_usage, batch_models
 
 
 def _get_batch_models_from_file_content(
     file_content_dictionary: List[dict],
+    model_name: Optional[str] = None,
 ) -> List[str]:
     """
     Get the models from the file content
     """
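+    # A caller-supplied model_name takes precedence over scanning each response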
+    if model_name:
+        return [model_name]
     batch_models = []
     for _item in file_content_dictionary:
         if _batch_response_was_successful(_item):
@@ -73,12 +80,18 @@ def _get_batch_models_from_file_content(
 def _batch_cost_calculator(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
 ) -> float:
     """
     Calculate the cost of a batch based on the output file id
     """
-    if custom_llm_provider == "vertex_ai":
-        raise ValueError("Vertex AI does not support file content retrieval")
+    # Handle Vertex AI with the specialized method
+    if custom_llm_provider == "vertex_ai" and model_name:
+        batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
+        verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
+        return batch_cost
+
+    # For other providers (and Vertex AI without a model_name), use the existing logic
     total_cost = _get_batch_job_cost_from_file_content(
         file_content_dictionary=file_content_dictionary,
     )
@@ -87,6 +100,87 @@ def _batch_cost_calculator(
     return total_cost
 
 
+def calculate_vertex_ai_batch_cost_and_usage(
+    vertex_ai_batch_responses: List[dict],
+    model_name: Optional[str] = None,
+) -> Tuple[float, Usage]:
+    """
+    Calculate both cost and usage from Vertex AI batch responses
+    """
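+    # Each item is expected to carry a "status" field and, on success, the raw
+    # Gemini generateContent body under "response"; failed rows are skipped.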
+    from litellm import ModelResponse
+    from litellm.litellm_core_utils.litellm_logging import Logging
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
+    from litellm.types.utils import CallTypes
+    from litellm._uuid import uuid
+    import httpx
+    import time
+
+    total_cost = 0.0
+    total_tokens = 0
+    prompt_tokens = 0
+    completion_tokens = 0
+
+    for response in vertex_ai_batch_responses:
+        if response.get("status") == "JOB_STATE_SUCCEEDED":  # Only successful rows are counted
+            # Transform the Vertex AI response to OpenAI format
+            # Arguments required by the transformation method
+            model_response = ModelResponse()
+
+            # Fall back to a default model so the transform never receives None
+            actual_model_name = model_name or "gemini-2.5-flash"
+
+            # Build a minimal LiteLLM logging object for the transform
+            logging_obj = Logging(
+                model=actual_model_name,
+                messages=[{"role": "user", "content": "batch_request"}],
+                stream=False,
+                call_type=CallTypes.aretrieve_batch,
+                start_time=time.time(),
+                litellm_call_id="batch_" + str(uuid.uuid4()),
+                function_id="batch_processing",
+                litellm_trace_id=str(uuid.uuid4()),
+                kwargs={"optional_params": {}},
+            )
+
+            # The Vertex AI transformation expects optional_params on the logging object
+            logging_obj.optional_params = {}
+            raw_response = httpx.Response(200)  # Mock response object
+
+            openai_format_response = VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response(
+                completion_response=response["response"],
+                model_response=model_response,
+                model=actual_model_name,
+                logging_obj=logging_obj,
+                raw_response=raw_response,
+            )
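+            # The transformed object is an OpenAI-shaped ModelResponse, so the
+            # existing cost and usage helpers can be reused on it below.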
+
+            # Calculate cost using the existing cost function
+            cost = litellm.completion_cost(
+                completion_response=openai_format_response,
+                custom_llm_provider="vertex_ai",
+                call_type=CallTypes.aretrieve_batch.value,
+            )
+            total_cost += cost
+
+            # Extract usage from the transformed response
+            if hasattr(openai_format_response, "usage") and openai_format_response.usage:
+                usage = openai_format_response.usage
+            else:
+                # Fallback: build usage from the response dict
+                response_dict = openai_format_response.dict() if hasattr(openai_format_response, "dict") else {}
+                usage = _get_batch_job_usage_from_response_body(response_dict)
+
+            total_tokens += usage.total_tokens
+            prompt_tokens += usage.prompt_tokens
+            completion_tokens += usage.completion_tokens
+
+    return total_cost, Usage(
+        total_tokens=total_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
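+
+# Example (hypothetical input shape):
+#   rows = [{"status": "JOB_STATE_SUCCEEDED", "response": {...}}]
+#   cost, usage = calculate_vertex_ai_batch_cost_and_usage(rows, model_name="gemini-2.5-flash")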
+
+
 async def _get_batch_output_file_content_as_dictionary(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -157,10 +251,17 @@ def _get_batch_job_cost_from_file_content(
 def _get_batch_job_total_usage_from_file_content(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
 ) -> Usage:
     """
     Get the tokens of a batch job from the file content
     """
+    # Handle Vertex AI with the specialized method
+    if custom_llm_provider == "vertex_ai" and model_name:
+        _, batch_usage = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
+        return batch_usage
+
+    # For other providers (and Vertex AI without a model_name), use the existing logic
     total_tokens: int = 0
     prompt_tokens: int = 0
     completion_tokens: int = 0