
Commit 77e0bb0

Merge pull request #15744 from BerriAI/litellm_batch_gemini_passthrough_support2
Add Vertex AI Batch Passthrough Support with Cost Tracking
2 parents 08c376d + e1f2f1e commit 77e0bb0

File tree

7 files changed: +1118 −14 lines changed
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# /batchPredictionJobs

LiteLLM supports Vertex AI batch prediction jobs through passthrough endpoints, allowing you to create and manage batch jobs directly through the proxy server.

## Features

- **Batch Job Creation**: Create batch prediction jobs using Vertex AI models
- **Cost Tracking**: Automatic cost calculation and usage tracking for batch operations
- **Status Monitoring**: Track job status and retrieve results
- **Model Support**: Works with all supported Vertex AI models (Gemini, Text Embedding)
## Cost Tracking Support

| Feature | Supported | Notes |
|---------|-----------|-------|
| Cost Tracking | ✅ | Automatic cost calculation for batch operations |
| Usage Monitoring | ✅ | Track token usage and costs across batch jobs |
| Logging | ✅ | Supported |
## Quick Start

1. **Configure your model** in the proxy configuration:

```yaml
model_list:
  - model_name: gemini-1.5-flash
    litellm_params:
      model: vertex_ai/gemini-1.5-flash
      vertex_project: your-project-id
      vertex_location: us-central1
      vertex_credentials: path/to/service-account.json
```

2. **Create a batch job**:

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
  -H "Authorization: Bearer your-api-key" \
  -H "Content-Type: application/json" \
  -d '{
    "displayName": "my-batch-job",
    "model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-flash",
    "inputConfig": {
      "gcsSource": {
        "uris": ["gs://my-bucket/input.jsonl"]
      },
      "instancesFormat": "jsonl"
    },
    "outputConfig": {
      "gcsDestination": {
        "outputUriPrefix": "gs://my-bucket/output/"
      },
      "predictionsFormat": "jsonl"
    }
  }'
```
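The file referenced in `gcsSource.uris` is a JSONL file with one request per line. A minimal sketch of a single line, assuming the Gemini `generateContent` request shape that Vertex AI batch prediction expects (this example file and its contents are illustrative, not part of this change):

```json
{"request": {"contents": [{"role": "user", "parts": [{"text": "Summarize this support ticket in one sentence."}]}]}}
```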
3. **Monitor job status**:

```bash
curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id" \
  -H "Authorization: Bearer your-api-key"
```
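For long-running jobs you can poll this endpoint until the job reaches a terminal state. A rough sketch, assuming `jq` is installed and relying on the `state` field of the Vertex AI `BatchPredictionJob` resource (the field and its `JOB_STATE_*` values come from the Vertex AI API, not from this change):

```bash
# Poll the passthrough endpoint every 30s until the job reaches a terminal state.
while true; do
  STATE=$(curl -s "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id" \
    -H "Authorization: Bearer your-api-key" | jq -r '.state')
  echo "current state: $STATE"
  case "$STATE" in
    JOB_STATE_SUCCEEDED|JOB_STATE_FAILED|JOB_STATE_CANCELLED) break ;;
  esac
  sleep 30
done
```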
## Model Configuration

When configuring models for batch operations, use these naming conventions (see the example below):

- **`model_name`**: Base model name (e.g., `gemini-1.5-flash`)
- **`model`**: Full LiteLLM identifier (e.g., `vertex_ai/gemini-1.5-flash`)
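For example, the Quick Start deployment above maps between the two names like this (a sketch; the project-specific values are placeholders):

```yaml
model_list:
  - model_name: gemini-1.5-flash          # base model name clients send to the proxy
    litellm_params:
      model: vertex_ai/gemini-1.5-flash   # full LiteLLM identifier: <provider>/<model>
      vertex_project: your-project-id
      vertex_location: us-central1
```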
## Supported Models

- `gemini-1.5-flash` / `vertex_ai/gemini-1.5-flash`
- `gemini-1.5-pro` / `vertex_ai/gemini-1.5-pro`
- `gemini-2.0-flash` / `vertex_ai/gemini-2.0-flash`
- `gemini-2.0-pro` / `vertex_ai/gemini-2.0-pro`
## Advanced Usage

### Batch Job with Custom Parameters

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
  -H "Authorization: Bearer your-api-key" \
  -H "Content-Type: application/json" \
  -d '{
    "displayName": "advanced-batch-job",
    "model": "projects/your-project/locations/us-central1/publishers/google/models/gemini-1.5-pro",
    "inputConfig": {
      "gcsSource": {
        "uris": ["gs://my-bucket/advanced-input.jsonl"]
      },
      "instancesFormat": "jsonl"
    },
    "outputConfig": {
      "gcsDestination": {
        "outputUriPrefix": "gs://my-bucket/advanced-output/"
      },
      "predictionsFormat": "jsonl"
    },
    "labels": {
      "environment": "production",
      "team": "ml-engineering"
    }
  }'
```

### List All Batch Jobs

```bash
curl -X GET "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs" \
  -H "Authorization: Bearer your-api-key"
```

### Cancel a Batch Job

```bash
curl -X POST "http://localhost:4000/v1/projects/your-project/locations/us-central1/batchPredictionJobs/job-id:cancel" \
  -H "Authorization: Bearer your-api-key"
```
## Cost Tracking Details

LiteLLM provides comprehensive cost tracking for Vertex AI batch operations:

- **Token Usage**: Tracks input and output tokens for each batch request
- **Cost Calculation**: Automatically calculates costs based on current Vertex AI pricing
- **Usage Aggregation**: Aggregates costs across all requests in a batch job
- **Real-time Monitoring**: Monitor costs as batch jobs progress

The cost tracking works seamlessly with the `generateContent` API and provides detailed insights into your batch processing expenses.
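Concretely, each line of the predictions output file carries the job state plus the raw `generateContent` response; LiteLLM converts each successful response to the OpenAI format and sums per-request cost and token usage across the file. A sketch of one output line, where the `status` and `response` keys match what the cost-tracking code in this commit reads, and the inner `candidates`/`usageMetadata` shape follows the standard Gemini response format (values are illustrative):

```json
{
  "status": "JOB_STATE_SUCCEEDED",
  "response": {
    "candidates": [
      {"content": {"role": "model", "parts": [{"text": "..."}]}, "finishReason": "STOP"}
    ],
    "usageMetadata": {
      "promptTokenCount": 12,
      "candidatesTokenCount": 34,
      "totalTokenCount": 46
    }
  }
}
```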
## Error Handling

Common error scenarios and their solutions:

| Error | Description | Solution |
|-------|-------------|----------|
| `INVALID_ARGUMENT` | Invalid model or configuration | Verify model name and project settings |
| `PERMISSION_DENIED` | Insufficient permissions | Check Vertex AI IAM roles |
| `RESOURCE_EXHAUSTED` | Quota exceeded | Check Vertex AI quotas and limits |
| `NOT_FOUND` | Job or resource not found | Verify job ID and project configuration |
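Failed passthrough calls surface the standard Google Cloud error envelope, so the status values in the table above appear in the response body. An illustrative example (the message text will vary; this shape comes from the Google API error format, not from this change):

```json
{
  "error": {
    "code": 403,
    "message": "Permission denied on resource project your-project-id.",
    "status": "PERMISSION_DENIED"
  }
}
```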
## Best Practices

1. **Use appropriate batch sizes**: Balance between processing efficiency and resource usage
2. **Monitor job status**: Regularly check job status to handle failures promptly
3. **Set up alerts**: Configure monitoring for job completion and failures
4. **Optimize costs**: Use cost tracking to identify optimization opportunities
5. **Test with small batches**: Validate your setup with small test batches first
## Related Documentation

- [Vertex AI Provider Documentation](./vertex.md)
- [General Batches API Documentation](../batches.md)
- [Cost Tracking and Monitoring](../observability/telemetry.md)

enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -57,7 +57,6 @@ async def check_batch_cost(self):
                 "file_purpose": "batch",
             }
         )
-
         completed_jobs = []
 
         for job in jobs:
@@ -139,7 +138,7 @@ async def check_batch_cost(self):
             custom_llm_provider = deployment_info.litellm_params.custom_llm_provider
             litellm_model_name = deployment_info.litellm_params.model
 
-            _, llm_provider, _, _ = get_llm_provider(
+            model_name, llm_provider, _, _ = get_llm_provider(
                 model=litellm_model_name,
                 custom_llm_provider=custom_llm_provider,
             )
@@ -148,9 +147,9 @@ async def check_batch_cost(self):
                 await calculate_batch_cost_and_usage(
                     file_content_dictionary=file_content_as_dict,
                     custom_llm_provider=llm_provider,  # type: ignore
+                    model_name=model_name,
                 )
             )
-
             logging_obj = LiteLLMLogging(
                 model=batch_models[0],
                 messages=[{"role": "user", "content": "<retrieve_batch>"}],
```

litellm/batches/batch_utils.py

Lines changed: 108 additions & 7 deletions
```diff
@@ -1,5 +1,5 @@
 import json
-from typing import Any, List, Literal, Tuple
+from typing import Any, List, Literal, Tuple, Optional
 
 import litellm
 from litellm._logging import verbose_logger
@@ -10,28 +10,30 @@
 async def calculate_batch_cost_and_usage(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
+    model_name: Optional[str] = None,
 ) -> Tuple[float, Usage, List[str]]:
     """
     Calculate the cost and usage of a batch
     """
-    # Calculate costs and usage
     batch_cost = _batch_cost_calculator(
         custom_llm_provider=custom_llm_provider,
         file_content_dictionary=file_content_dictionary,
+        model_name=model_name,
     )
     batch_usage = _get_batch_job_total_usage_from_file_content(
         file_content_dictionary=file_content_dictionary,
         custom_llm_provider=custom_llm_provider,
+        model_name=model_name,
     )
-
-    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
 
     return batch_cost, batch_usage, batch_models
 
 
 async def _handle_completed_batch(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
+    model_name: Optional[str] = None,
 ) -> Tuple[float, Usage, List[str]]:
     """Helper function to process a completed batch and handle logging"""
     # Get batch results
@@ -43,23 +45,28 @@ async def _handle_completed_batch(
     batch_cost = _batch_cost_calculator(
         custom_llm_provider=custom_llm_provider,
         file_content_dictionary=file_content_dictionary,
+        model_name=model_name,
     )
     batch_usage = _get_batch_job_total_usage_from_file_content(
         file_content_dictionary=file_content_dictionary,
         custom_llm_provider=custom_llm_provider,
+        model_name=model_name,
     )
 
-    batch_models = _get_batch_models_from_file_content(file_content_dictionary)
+    batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
 
     return batch_cost, batch_usage, batch_models
 
 
 def _get_batch_models_from_file_content(
     file_content_dictionary: List[dict],
+    model_name: Optional[str] = None,
 ) -> List[str]:
     """
     Get the models from the file content
     """
+    if model_name:
+        return [model_name]
     batch_models = []
     for _item in file_content_dictionary:
         if _batch_response_was_successful(_item):
@@ -73,12 +80,18 @@ def _get_batch_models_from_file_content(
 def _batch_cost_calculator(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
 ) -> float:
     """
     Calculate the cost of a batch based on the output file id
     """
-    if custom_llm_provider == "vertex_ai":
-        raise ValueError("Vertex AI does not support file content retrieval")
+    # Handle Vertex AI with specialized method
+    if custom_llm_provider == "vertex_ai" and model_name:
+        batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
+        verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
+        return batch_cost
+
+    # For other providers, use the existing logic
     total_cost = _get_batch_job_cost_from_file_content(
         file_content_dictionary=file_content_dictionary,
         custom_llm_provider=custom_llm_provider,
@@ -87,6 +100,87 @@ def _batch_cost_calculator(
     return total_cost
 
 
+def calculate_vertex_ai_batch_cost_and_usage(
+    vertex_ai_batch_responses: List[dict],
+    model_name: Optional[str] = None,
+) -> Tuple[float, Usage]:
+    """
+    Calculate both cost and usage from Vertex AI batch responses
+    """
+    total_cost = 0.0
+    total_tokens = 0
+    prompt_tokens = 0
+    completion_tokens = 0
+
+    for response in vertex_ai_batch_responses:
+        if response.get("status") == "JOB_STATE_SUCCEEDED":  # Check if response was successful
+            # Transform Vertex AI response to OpenAI format if needed
+            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
+            from litellm import ModelResponse
+            from litellm.litellm_core_utils.litellm_logging import Logging
+            from litellm.types.utils import CallTypes
+            from litellm._uuid import uuid
+            import httpx
+            import time
+
+            # Create required arguments for the transformation method
+            model_response = ModelResponse()
+
+            # Ensure model_name is not None
+            actual_model_name = model_name or "gemini-2.5-flash"
+
+            # Create a real LiteLLM logging object
+            logging_obj = Logging(
+                model=actual_model_name,
+                messages=[{"role": "user", "content": "batch_request"}],
+                stream=False,
+                call_type=CallTypes.aretrieve_batch,
+                start_time=time.time(),
+                litellm_call_id="batch_" + str(uuid.uuid4()),
+                function_id="batch_processing",
+                litellm_trace_id=str(uuid.uuid4()),
+                kwargs={"optional_params": {}}
+            )
+
+            # Add the optional_params attribute that the Vertex AI transformation expects
+            logging_obj.optional_params = {}
+            raw_response = httpx.Response(200)  # Mock response object
+
+            openai_format_response = VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response(
+                completion_response=response["response"],
+                model_response=model_response,
+                model=actual_model_name,
+                logging_obj=logging_obj,
+                raw_response=raw_response,
+            )
+
+            # Calculate cost using existing function
+            cost = litellm.completion_cost(
+                completion_response=openai_format_response,
+                custom_llm_provider="vertex_ai",
+                call_type=CallTypes.aretrieve_batch.value,
+            )
+            total_cost += cost
+
+            # Extract usage from the transformed response
+            if hasattr(openai_format_response, 'usage') and openai_format_response.usage:
+                usage = openai_format_response.usage
+            else:
+                # Fallback: create usage from response dict
+                response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
+                usage = _get_batch_job_usage_from_response_body(response_dict)
+
+            total_tokens += usage.total_tokens
+            prompt_tokens += usage.prompt_tokens
+            completion_tokens += usage.completion_tokens
+
+    return total_cost, Usage(
+        total_tokens=total_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+
+
 async def _get_batch_output_file_content_as_dictionary(
     batch: Batch,
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@@ -157,10 +251,17 @@ def _get_batch_job_cost_from_file_content(
 def _get_batch_job_total_usage_from_file_content(
     file_content_dictionary: List[dict],
     custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
 ) -> Usage:
     """
     Get the tokens of a batch job from the file content
     """
+    # Handle Vertex AI with specialized method
+    if custom_llm_provider == "vertex_ai" and model_name:
+        _, batch_usage = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
+        return batch_usage
+
+    # For other providers, use the existing logic
     total_tokens: int = 0
     prompt_tokens: int = 0
     completion_tokens: int = 0
```
