diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md
index dc34b56c3ffc..6c4783329bd3 100644
--- a/docs/my-website/docs/proxy/config_settings.md
+++ b/docs/my-website/docs/proxy/config_settings.md
@@ -101,6 +101,7 @@ general_settings:
   disable_retry_on_max_parallel_request_limit_error: boolean # turn off retries when max parallel request limit is reached
   disable_reset_budget: boolean # turn off reset budget scheduled task
   disable_adding_master_key_hash_to_db: boolean # turn off storing master key hash in db, for spend tracking
+  disable_responses_id_security: boolean # turn off response ID security checks that prevent users from accessing other users' responses
   enable_jwt_auth: boolean # allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims
   enforce_user_param: boolean # requires all openai endpoint requests to have a 'user' param
   allowed_routes: ["route1", "route2"] # list of allowed proxy API routes - a user can access. (currently JWT-Auth only)
@@ -197,6 +198,7 @@ router_settings:
 | disable_retry_on_max_parallel_request_limit_error | boolean | If true, turns off retries when max parallel request limit is reached |
 | disable_reset_budget | boolean | If true, turns off reset budget scheduled task |
 | disable_adding_master_key_hash_to_db | boolean | If true, turns off storing master key hash in db |
+| disable_responses_id_security | boolean | If true, disables the response ID security checks that prevent users from accessing response IDs created by other users. When false (default), response IDs are encrypted with user information so users can only access their own responses. Applies to /v1/responses endpoints |
 | enable_jwt_auth | boolean | allow proxy admin to auth in via jwt tokens with 'litellm_proxy_admin' in claims. [Doc on JWT Tokens](token_auth) |
 | enforce_user_param | boolean | If true, requires all OpenAI endpoint requests to have a 'user' param. [Doc on call hooks](call_hooks)|
 | allowed_routes | array of strings | List of allowed proxy API routes a user can access [Doc on controlling allowed routes](enterprise#control-available-public-private-routes)|
diff --git a/docs/my-website/docs/response_api.md b/docs/my-website/docs/response_api.md
index fe21266958fc..dfc02e3b34d0 100644
--- a/docs/my-website/docs/response_api.md
+++ b/docs/my-website/docs/response_api.md
@@ -699,6 +699,46 @@ for event in response:
 
 
 
+## Response ID Security
+
+By default, LiteLLM Proxy prevents users from accessing other users' response IDs.
+
+It does this by encrypting each response ID with the calling user's user and team IDs, so a user can only access response IDs they created.
+
+Attempting to access another user's response ID returns a 403:
+
+```json
+{
+  "error": {
+    "message": "Forbidden. The response id is not associated with the user this key belongs to.",
+    "code": 403
+  }
+}
+```
+
+To disable this, set `disable_responses_id_security: true`:
+
+```yaml
+general_settings:
+  disable_responses_id_security: true
+```
+
+This allows any user to access any response ID.
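+
+The encryption is transparent to clients: the ID returned by `/v1/responses` is already the encrypted ID, and follow-up calls accept it and decrypt it server-side. A minimal sketch with the OpenAI SDK pointed at the proxy (example base URL, key, and model):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")
+
+# the returned id is encrypted with the caller's user/team info
+response = client.responses.create(model="gpt-4o", input="Hello!")
+
+# retrieving with the same key succeeds; another user's key gets a 403
+retrieved = client.responses.retrieve(response.id)
+```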
+ ## Supported Responses API Parameters | Provider | Supported Parameters | diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 83c5a49edf27..3d13cc513526 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -182,20 +182,20 @@ def _get_cost_breakdown_from_logging_obj( ) -> Tuple[Optional[float], Optional[float]]: """ Extract discount information from logging object's cost breakdown. - + Returns: Tuple of (original_cost, discount_amount) """ if not litellm_logging_obj or not hasattr(litellm_logging_obj, "cost_breakdown"): return None, None - + cost_breakdown = litellm_logging_obj.cost_breakdown if not cost_breakdown: return None, None - + original_cost = cost_breakdown.get("original_cost") discount_amount = cost_breakdown.get("discount_amount") - + return original_cost, discount_amount @@ -223,12 +223,12 @@ def get_custom_headers( ) -> dict: exclude_values = {"", None, "None"} hidden_params = hidden_params or {} - + # Extract discount info from cost_breakdown if available original_cost, discount_amount = _get_cost_breakdown_from_logging_obj( litellm_logging_obj=litellm_logging_obj ) - + headers = { "x-litellm-call-id": call_id, "x-litellm-model-id": model_id, @@ -239,8 +239,12 @@ def get_custom_headers( "x-litellm-version": version, "x-litellm-model-region": model_region, "x-litellm-response-cost": str(response_cost), - "x-litellm-response-cost-original": str(original_cost) if original_cost is not None else None, - "x-litellm-response-cost-discount-amount": str(discount_amount) if discount_amount is not None else None, + "x-litellm-response-cost-original": ( + str(original_cost) if original_cost is not None else None + ), + "x-litellm-response-cost-discount-amount": ( + str(discount_amount) if discount_amount is not None else None + ), "x-litellm-key-tpm-limit": str(user_api_key_dict.tpm_limit), "x-litellm-key-rpm-limit": str(user_api_key_dict.rpm_limit), "x-litellm-key-max-budget": str(user_api_key_dict.max_budget), @@ -327,6 +331,7 @@ async def common_processing_pre_call_logic( model: Optional[str] = None, ) -> Tuple[dict, LiteLLMLoggingObj]: start_time = datetime.now() # start before calling guardrail hooks + self.data = await add_litellm_data_to_request( data=self.data, request=request, @@ -790,7 +795,9 @@ async def async_sse_data_generator( verbose_proxy_logger.debug("inside generator") try: str_so_far = "" - async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook( + async for ( + chunk + ) in proxy_logging_obj.async_post_call_streaming_iterator_hook( user_api_key_dict=user_api_key_dict, response=response, request_data=request_data, @@ -812,7 +819,11 @@ async def async_sse_data_generator( # Inject cost into Anthropic-style SSE usage for /v1/messages for any provider model_name = request_data.get("model", "") - chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(chunk, model_name) + chunk = ( + ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection( + chunk, model_name + ) + ) # Format chunk using helper function yield ProxyBaseLLMRequestProcessing.return_sse_chunk(chunk) @@ -850,52 +861,72 @@ async def async_sse_data_generator( def _process_chunk_with_cost_injection(chunk: Any, model_name: str) -> Any: """ Process a streaming chunk and inject cost information if enabled. 
- + Args: chunk: The streaming chunk (dict, str, bytes, or bytearray) model_name: Model name for cost calculation - + Returns: The processed chunk with cost information injected if applicable """ if not getattr(litellm, "include_cost_in_streaming_usage", False): return chunk - + try: if isinstance(chunk, dict): - maybe_modified = ProxyBaseLLMRequestProcessing._inject_cost_into_usage_dict(chunk, model_name) + maybe_modified = ( + ProxyBaseLLMRequestProcessing._inject_cost_into_usage_dict( + chunk, model_name + ) + ) if maybe_modified is not None: return maybe_modified elif isinstance(chunk, (bytes, bytearray)): # Decode to str, inject, and rebuild as bytes try: s = chunk.decode("utf-8", errors="ignore") - maybe_mod = ProxyBaseLLMRequestProcessing._inject_cost_into_sse_frame_str(s, model_name) + maybe_mod = ( + ProxyBaseLLMRequestProcessing._inject_cost_into_sse_frame_str( + s, model_name + ) + ) if maybe_mod is not None: - return (maybe_mod + ("" if maybe_mod.endswith("\n\n") else "\n\n")).encode("utf-8") + return ( + maybe_mod + ("" if maybe_mod.endswith("\n\n") else "\n\n") + ).encode("utf-8") except Exception: pass elif isinstance(chunk, str): # Try to parse SSE frame and inject cost into the data line - maybe_mod = ProxyBaseLLMRequestProcessing._inject_cost_into_sse_frame_str(chunk, model_name) + maybe_mod = ( + ProxyBaseLLMRequestProcessing._inject_cost_into_sse_frame_str( + chunk, model_name + ) + ) if maybe_mod is not None: # Ensure trailing frame separator - return maybe_mod if maybe_mod.endswith("\n\n") else (maybe_mod + "\n\n") + return ( + maybe_mod + if maybe_mod.endswith("\n\n") + else (maybe_mod + "\n\n") + ) except Exception: # Never break streaming on optional cost injection pass - + return chunk - + @staticmethod - def _inject_cost_into_sse_frame_str(frame_str: str, model_name: str) -> Optional[str]: + def _inject_cost_into_sse_frame_str( + frame_str: str, model_name: str + ) -> Optional[str]: """ Inject cost information into an SSE frame string by modifying the JSON in the 'data:' line. - + Args: frame_str: SSE frame string that may contain multiple lines model_name: Model name for cost calculation - + Returns: Modified SSE frame string with cost injected, or None if no modification needed """ @@ -908,7 +939,11 @@ def _inject_cost_into_sse_frame_str(frame_str: str, model_name: str) -> Optional json_part = stripped_ln.split("data:", 1)[1].strip() if json_part and json_part != "[DONE]": obj = json.loads(json_part) - maybe_modified = ProxyBaseLLMRequestProcessing._inject_cost_into_usage_dict(obj, model_name) + maybe_modified = ( + ProxyBaseLLMRequestProcessing._inject_cost_into_usage_dict( + obj, model_name + ) + ) if maybe_modified is not None: # Replace just this line with updated JSON using safe_dumps lines[idx] = f"data: {safe_dumps(maybe_modified)}" @@ -916,23 +951,20 @@ def _inject_cost_into_sse_frame_str(frame_str: str, model_name: str) -> Optional return None except Exception: return None - + @staticmethod def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]: """ Inject cost information into a usage dictionary for message_delta events. 
- + Args: obj: Dictionary containing the SSE event data model_name: Model name for cost calculation - + Returns: Modified dictionary with cost injected, or None if no modification needed """ - if ( - obj.get("type") == "message_delta" - and isinstance(obj.get("usage"), dict) - ): + if obj.get("type") == "message_delta" and isinstance(obj.get("usage"), dict): _usage = obj["usage"] prompt_tokens = int(_usage.get("input_tokens", 0) or 0) completion_tokens = int(_usage.get("output_tokens", 0) or 0) @@ -948,35 +980,34 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]: completion_tokens_details = _usage.get("completion_tokens_details") prompt_tokens_details = _usage.get("prompt_tokens_details") - usage_kwargs: dict[str, Any] = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": total_tokens, } - + # Add optional named parameters if completion_tokens_details is not None: usage_kwargs["completion_tokens_details"] = completion_tokens_details if prompt_tokens_details is not None: usage_kwargs["prompt_tokens_details"] = prompt_tokens_details - + # Handle web_search_requests by wrapping in ServerToolUse if web_search_requests is not None: usage_kwargs["server_tool_use"] = ServerToolUse( web_search_requests=web_search_requests ) - + # Add cache-related fields to **params (handled by Usage.__init__) if cache_creation_input_tokens is not None: - usage_kwargs["cache_creation_input_tokens"] = cache_creation_input_tokens + usage_kwargs["cache_creation_input_tokens"] = ( + cache_creation_input_tokens + ) if cache_read_input_tokens is not None: usage_kwargs["cache_read_input_tokens"] = cache_read_input_tokens - _mr = ModelResponse( - usage=Usage(**usage_kwargs) - ) - + _mr = ModelResponse(usage=Usage(**usage_kwargs)) + try: cost_val = litellm.completion_cost( completion_response=_mr, @@ -984,8 +1015,8 @@ def _inject_cost_into_usage_dict(obj: dict, model_name: str) -> Optional[dict]: ) except Exception: cost_val = None - + if cost_val is not None: obj.setdefault("usage", {})["cost"] = cost_val return obj - return None \ No newline at end of file + return None diff --git a/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/litellm/proxy/common_utils/encrypt_decrypt_utils.py index c3b7b55054ad..6c47d220c4f4 100644 --- a/litellm/proxy/common_utils/encrypt_decrypt_utils.py +++ b/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -22,7 +22,8 @@ def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): try: if isinstance(value, str): encrypted_value = encrypt_value(value=value, signing_key=signing_key) # type: ignore - encrypted_value = base64.b64encode(encrypted_value).decode("utf-8") + # Use urlsafe_b64encode for URL-safe base64 encoding (replaces + with - and / with _) + encrypted_value = base64.urlsafe_b64encode(encrypted_value).decode("utf-8") return encrypted_value @@ -45,7 +46,14 @@ def decrypt_value_helper( try: if isinstance(value, str): - decoded_b64 = base64.b64decode(value) + # Try URL-safe base64 decoding first (new format) + # Fall back to standard base64 decoding for backwards compatibility (old format) + try: + decoded_b64 = base64.urlsafe_b64decode(value) + except Exception: + # If URL-safe decoding fails, try standard base64 decoding for backwards compatibility + decoded_b64 = base64.b64decode(value) + value = decrypt_value(value=decoded_b64, signing_key=signing_key) # type: ignore return value diff --git a/litellm/proxy/hooks/__init__.py b/litellm/proxy/hooks/__init__.py index 
467565d748a0..0f8f995bc2c9 100644
--- a/litellm/proxy/hooks/__init__.py
+++ b/litellm/proxy/hooks/__init__.py
@@ -6,6 +6,7 @@
 from .max_budget_limiter import _PROXY_MaxBudgetLimiter
 from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
 from .parallel_request_limiter_v3 import _PROXY_MaxParallelRequestsHandler_v3
+from .responses_id_security import ResponsesIDSecurity
 
 ### CHECK IF ENTERPRISE HOOKS ARE AVAILABLE ###
 
@@ -19,6 +20,7 @@
     "max_budget_limiter": _PROXY_MaxBudgetLimiter,
     "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler_v3,
     "cache_control_check": _PROXY_CacheControlCheck,
+    "responses_id_security": ResponsesIDSecurity,
 }
 
 ## FEATURE FLAG HOOKS ##
@@ -40,7 +42,7 @@ def get_proxy_hook(
             "cache_control_check",
         ],
         str,
-    ]
+    ],
 ):
     """
     Factory method to get a proxy hook instance by name
diff --git a/litellm/proxy/hooks/responses_id_security.py b/litellm/proxy/hooks/responses_id_security.py
new file mode 100644
index 000000000000..cfc2e10f40b7
--- /dev/null
+++ b/litellm/proxy/hooks/responses_id_security.py
@@ -0,0 +1,276 @@
+"""
+Security hook to prevent user B from seeing responses from user A.
+
+Response IDs returned by the Responses API are encrypted together with the
+calling user's user_id and team_id; on later requests the hook decrypts the
+ID and verifies that the caller is allowed to access it.
+"""
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Literal,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
+
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import LitellmUserRoles
+from litellm.proxy.common_utils.encrypt_decrypt_utils import (
+    decrypt_value_helper,
+    encrypt_value_helper,
+)
+from litellm.types.llms.openai import (
+    BaseLiteLLMOpenAIResponseObject,
+    ResponsesAPIResponse,
+)
+from litellm.types.utils import LLMResponseTypes, SpecialEnums
+
+if TYPE_CHECKING:
+    from litellm.caching.caching import DualCache
+    from litellm.proxy._types import UserAPIKeyAuth
+
+
+class ResponsesIDSecurity(CustomLogger):
+    def __init__(self):
+        pass
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: "UserAPIKeyAuth",
+        cache: "DualCache",
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+            "mcp_call",
+            "anthropic_messages",
+            "aresponses",
+            "aget_responses",
+            "adelete_responses",
+            "acancel_responses",
+        ],
+    ) -> Optional[Union[Exception, str, dict]]:
+        # Map encrypted Responses API response ids back to their original, decrypted ids
+        responses_api_call_types = {
+            "aresponses",
+            "aget_responses",
+            "adelete_responses",
+            "acancel_responses",
+        }
+        if call_type not in responses_api_call_types:
+            return None
+        if call_type == "aresponses":
+            # check 'previous_response_id' if present in the data
+            previous_response_id = data.get("previous_response_id")
+            if previous_response_id and self._is_encrypted_response_id(
+                previous_response_id
+            ):
+                original_response_id, user_id, team_id = self._decrypt_response_id(
+                    previous_response_id
+                )
+                self.check_user_access_to_response_id(
+                    user_id, team_id, user_api_key_dict
+                )
+                data["previous_response_id"] = original_response_id
+        elif call_type in {"aget_responses", "adelete_responses", "acancel_responses"}:
+            response_id = data.get("response_id")
+
+            if response_id and self._is_encrypted_response_id(response_id):
+                original_response_id, user_id, team_id = self._decrypt_response_id(
+                    response_id
+                )
+
+                self.check_user_access_to_response_id(
+                    user_id, team_id, user_api_key_dict
+                )
+
+                data["response_id"] = original_response_id
+        return data
+
+    def check_user_access_to_response_id(
+        self,
+        response_id_user_id: Optional[str],
+        response_id_team_id: Optional[str],
+        user_api_key_dict: "UserAPIKeyAuth",
+    ) -> bool:
+        from litellm.proxy.proxy_server import general_settings
+
+        if (
+            user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
+            or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
+        ):
+            return True
+
+        if response_id_user_id and response_id_user_id != user_api_key_dict.user_id:
+            if general_settings.get("disable_responses_id_security", False):
+                verbose_proxy_logger.debug(
+                    f"Responses ID Security is disabled. User {user_api_key_dict.user_id} is accessing a response created by user {response_id_user_id}, which is not associated with them."
+                )
+                return True
+            raise HTTPException(
+                status_code=403,
+                detail="Forbidden. The response id is not associated with the user this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
+            )
+
+        if response_id_team_id and response_id_team_id != user_api_key_dict.team_id:
+            if general_settings.get("disable_responses_id_security", False):
+                verbose_proxy_logger.debug(
+                    f"Responses ID Security is disabled. Response belongs to team {response_id_team_id} but user {user_api_key_dict.user_id} is accessing it with team id {user_api_key_dict.team_id}."
+                )
+                return True
+            raise HTTPException(
+                status_code=403,
+                detail="Forbidden. The response id is not associated with the team this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
+            )
+
+        return True
+
+    def _is_encrypted_response_id(self, response_id: str) -> bool:
+        if not response_id.startswith("resp_"):
+            return False
+
+        remaining_string = response_id.split("resp_")[1]
+        decrypted_value = decrypt_value_helper(
+            value=remaining_string, key="response_id", return_original_value=True
+        )
+
+        if decrypted_value is None:
+            return False
+
+        if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+            return True
+        return False
+
+    def _decrypt_response_id(
+        self, response_id: str
+    ) -> Tuple[str, Optional[str], Optional[str]]:
+        """
+        Returns:
+        - original_response_id: the original response id
+        - user_id: the user id
+        - team_id: the team id
+        """
+        if not response_id.startswith("resp_"):
+            return response_id, None, None
+
+        remaining_string = response_id.split("resp_")[1]
+        decrypted_value = decrypt_value_helper(
+            value=remaining_string, key="response_id", return_original_value=True
+        )
+
+        if decrypted_value is None:
+            return response_id, None, None
+
+        if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
+            # Expected format: "litellm_proxy:responses_api:response_id:{response_id};user_id:{user_id};team_id:{team_id}"
+            parts = decrypted_value.split(";")
+
+            if len(parts) >= 3:
+                # Extract response_id from "litellm_proxy:responses_api:response_id:{response_id}"
+                response_id_part = parts[0]
+                original_response_id = response_id_part.split("response_id:")[-1]
+
+                # Extract user_id from "user_id:{user_id}"
+                user_id_part = parts[1]
+                user_id = user_id_part.split("user_id:")[-1]
+
+                # Extract team_id from "team_id:{team_id}"
+                team_id_part = parts[2]
+                team_id = team_id_part.split("team_id:")[-1]
+
+                return original_response_id, user_id, team_id
+            else:
+                # Fallback if format is unexpected
+                return response_id, None, None
+        return response_id, None, None
+
+    def _encrypt_response_id(
+        self,
+        response: BaseLiteLLMOpenAIResponseObject,
+        user_api_key_dict: "UserAPIKeyAuth",
+    ) -> BaseLiteLLMOpenAIResponseObject:
+        # encrypt the response id using the symmetric key;
+        # the encrypted payload encodes the response id, user id and team id
+        response_id = getattr(response, "id", None)
+        response_obj = getattr(response, "response", None)
+
+        if (
+            response_id
+            and isinstance(response_id, str)
+            and response_id.startswith("resp_")
+        ):
+            encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
+                response_id,
+                user_api_key_dict.user_id or "",
+                user_api_key_dict.team_id or "",
+            )
+
+            encoded_user_id_and_response_id = encrypt_value_helper(
+                value=encrypted_response_id
+            )
+            setattr(
+                response, "id", f"resp_{encoded_user_id_and_response_id}"
+            )  # maintain the 'resp_' prefix for the responses api response id
+
+        elif response_obj and isinstance(response_obj, ResponsesAPIResponse):
+            encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
+                response_obj.id,
+                user_api_key_dict.user_id or "",
+                user_api_key_dict.team_id or "",
+            )
+            encoded_user_id_and_response_id = encrypt_value_helper(
+                value=encrypted_response_id
+            )
+            setattr(
+                response_obj, "id", f"resp_{encoded_user_id_and_response_id}"
+            )  # maintain the 'resp_' prefix for the responses api response id
+            setattr(response, "response", response_obj)
+        return response
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: "UserAPIKeyAuth",
+        response: LLMResponseTypes,
+    ) -> Any:
+        """
+        Encrypt the response ID on Responses API responses before returning them
+        to the client, so the caller's user/team info travels with the ID.
+        """
+        from litellm.proxy.proxy_server import general_settings
+
+        if general_settings.get("disable_responses_id_security", False):
+            return response
+        if isinstance(response, ResponsesAPIResponse):
+            response = cast(
+                ResponsesAPIResponse,
+                self._encrypt_response_id(response, user_api_key_dict),
+            )
+        return response
+
+    async def async_post_call_streaming_iterator_hook(  # type: ignore
+        self, user_api_key_dict: "UserAPIKeyAuth", response: Any, request_data: dict
+    ) -> AsyncGenerator[BaseLiteLLMOpenAIResponseObject, None]:
+        from litellm.proxy.proxy_server import general_settings
+
+        async for chunk in response:
+            if (
+                isinstance(chunk, BaseLiteLLMOpenAIResponseObject)
+                and user_api_key_dict.request_route
+                == "/v1/responses"  # only encrypt the response id for the responses api
+                and not general_settings.get("disable_responses_id_security", False)
+            ):
+                chunk = self._encrypt_response_id(chunk, user_api_key_dict)
+            yield chunk
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 0bd5ffc584b7..eba2caa4c2a5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -488,9 +488,9 @@ def generate_feedback_box():
 server_root_path = os.getenv("SERVER_ROOT_PATH", "")
 _license_check = LicenseCheck()
 premium_user: bool = _license_check.is_premium()
-premium_user_data: Optional[
-    "EnterpriseLicenseData"
-] = _license_check.airgapped_license_data
+premium_user_data: Optional["EnterpriseLicenseData"] = (
+    _license_check.airgapped_license_data
+)
 global_max_parallel_request_retries_env: Optional[str] = os.getenv(
     "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES"
 )
@@ -1017,9 +1017,9 @@ def swagger_monkey_patch(*args, **kwargs):
 master_key: Optional[str] = None
 otel_logging = False
 prisma_client: Optional[PrismaClient] = None
-shared_aiohttp_session: Optional[ - "ClientSession" -] = None # Global shared session for connection reuse +shared_aiohttp_session: Optional["ClientSession"] = ( + None # Global shared session for connection reuse +) user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1027,9 +1027,9 @@ def swagger_monkey_patch(*args, **kwargs): dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[ - RedisCache -] = None # redis cache used for tracking spend, tpm/rpm limits +redis_usage_cache: Optional[RedisCache] = ( + None # redis cache used for tracking spend, tpm/rpm limits +) user_custom_auth = None user_custom_key_generate = None user_custom_sso = None @@ -1362,9 +1362,9 @@ async def _update_team_cache(): _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[ - LiteLLM_TeamTable - ] = await user_api_key_cache.async_get_cache(key=_id) + existing_spend_obj: Optional[LiteLLM_TeamTable] = ( + await user_api_key_cache.async_get_cache(key=_id) + ) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -3455,10 +3455,10 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient): ) try: - guardrails_in_db: List[ - Guardrail - ] = await GuardrailRegistry.get_all_guardrails_from_db( - prisma_client=prisma_client + guardrails_in_db: List[Guardrail] = ( + await GuardrailRegistry.get_all_guardrails_from_db( + prisma_client=prisma_client + ) ) verbose_proxy_logger.debug( "guardrails from the DB %s", str(guardrails_in_db) @@ -3725,9 +3725,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ[ - "AZURE_API_VERSION" - ] = api_version # set this for azure - litellm can read this from the env + os.environ["AZURE_API_VERSION"] = ( + api_version # set this for azure - litellm can read this from the env + ) if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -9095,9 +9095,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[ - idx - ].field_description = sub_field_info.description + nested_fields[idx].field_description = ( + sub_field_info.description + ) idx += 1 _stored_in_db = None diff --git a/litellm/types/utils.py b/litellm/types/utils.py index ba5b54d38a1d..964047fc57bc 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -2763,6 +2763,10 @@ class SpecialEnums(Enum): LITELLM_MANAGED_BATCH_COMPLETE_STR = "litellm_proxy;model_id:{};llm_batch_id:{}" + LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR = ( + "litellm_proxy:responses_api:response_id:{};user_id:{};team_id:{}" + ) + LITELLM_MANAGED_GENERIC_RESPONSE_COMPLETE_STR = "litellm_proxy;model_id:{};generic_response_id:{}" # generic implementation of 'managed batches' - used for finetuning and any future work. 
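For reference, the new `LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR` template above is the payload that `ResponsesIDSecurity` formats before encryption and splits on `;` after decryption. A standalone sketch of that round trip (illustrative values only):

```python
# Illustrative sketch of the payload format used by ResponsesIDSecurity
TEMPLATE = "litellm_proxy:responses_api:response_id:{};user_id:{};team_id:{}"

payload = TEMPLATE.format("resp_abc123", "user-1", "team-1")
# -> "litellm_proxy:responses_api:response_id:resp_abc123;user_id:user-1;team_id:team-1"

parts = payload.split(";")
original_response_id = parts[0].split("response_id:")[-1]  # "resp_abc123"
user_id = parts[1].split("user_id:")[-1]  # "user-1"
team_id = parts[2].split("team_id:")[-1]  # "team-1"
```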
diff --git a/litellm/utils.py b/litellm/utils.py
index b36872b0e128..87936d919c64 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4419,6 +4419,33 @@ def _count_characters(text: str) -> int:
 
 
 def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
+    # Handle Responses API streaming events
+    if hasattr(response_obj, "type") and hasattr(response_obj, "response"):
+        # This is a Responses API streaming event (e.g., ResponseCreatedEvent, ResponseCompletedEvent)
+        # Extract text from the response object's output if available
+        responses_api_response = getattr(response_obj, "response", None)
+        if responses_api_response and hasattr(responses_api_response, "output"):
+            output_list = responses_api_response.output
+            response_str = ""
+            for output_item in output_list:
+                # Handle output items with content array
+                if hasattr(output_item, "content"):
+                    for content_part in output_item.content:
+                        if hasattr(content_part, "text"):
+                            response_str += content_part.text
+                # Handle output items with direct text field
+                elif hasattr(output_item, "text"):
+                    response_str += output_item.text
+            return response_str
+
+    # Handle Responses API text delta events
+    if hasattr(response_obj, "type") and hasattr(response_obj, "delta"):
+        event_type = getattr(response_obj, "type", "")
+        if "text.delta" in event_type:  # also matches "response.output_text.delta"
+            delta = getattr(response_obj, "delta", "")
+            return delta if isinstance(delta, str) else ""
+
+    # Handle standard ModelResponse and ModelResponseStream
     _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
         response_obj.choices
     )
diff --git a/tests/test_litellm/test_responses_id_security.py b/tests/test_litellm/test_responses_id_security.py
new file mode 100644
index 000000000000..16e57f73cf1e
--- /dev/null
+++ b/tests/test_litellm/test_responses_id_security.py
@@ -0,0 +1,544 @@
+"""
+Tests for ResponsesIDSecurity hook.
+
+Tests the security hook that prevents user B from seeing responses from user A.
+""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from litellm.proxy.hooks.responses_id_security import ResponsesIDSecurity +from litellm.types.llms.openai import ResponsesAPIResponse +from litellm.types.utils import SpecialEnums + + +@pytest.fixture +def responses_id_security(): + """Fixture that creates a ResponsesIDSecurity instance.""" + return ResponsesIDSecurity() + + +@pytest.fixture +def mock_user_api_key_dict(): + """Fixture that creates a mock UserAPIKeyAuth object.""" + mock_auth = MagicMock() + mock_auth.user_id = "test-user-123" + mock_auth.team_id = "test-team-123" + mock_auth.token = "test-token" + mock_auth.user_role = None + return mock_auth + + +@pytest.fixture +def mock_cache(): + """Fixture that creates a mock DualCache object.""" + return MagicMock() + + +class TestIsEncryptedResponseId: + """Test _is_encrypted_response_id function""" + + def test_is_encrypted_response_id_valid(self, responses_id_security): + """Test that a properly encrypted response ID is identified correctly""" + with patch( + "litellm.proxy.hooks.responses_id_security.decrypt_value_helper" + ) as mock_decrypt: + mock_decrypt.return_value = f"{SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value}response_id:resp_123;user_id:user-456" + + result = responses_id_security._is_encrypted_response_id( + "resp_encrypted_value" + ) + + assert result is True + mock_decrypt.assert_called_once() + + def test_is_encrypted_response_id_invalid(self, responses_id_security): + """Test that an unencrypted response ID returns False""" + with patch( + "litellm.proxy.hooks.responses_id_security.decrypt_value_helper" + ) as mock_decrypt: + mock_decrypt.return_value = None + + result = responses_id_security._is_encrypted_response_id("resp_plain_value") + + assert result is False + + +class TestDecryptResponseId: + """Test _decrypt_response_id function""" + + def test_decrypt_response_id_valid(self, responses_id_security): + """Test decrypting a valid encrypted response ID""" + with patch( + "litellm.proxy.hooks.responses_id_security.decrypt_value_helper" + ) as mock_decrypt: + mock_decrypt.return_value = f"{SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value}response_id:resp_original_123;user_id:user-456;team_id:team-789" + + original_id, user_id, team_id = responses_id_security._decrypt_response_id( + "resp_encrypted_value" + ) + + assert original_id == "resp_original_123" + assert user_id == "user-456" + assert team_id == "team-789" + + def test_decrypt_response_id_no_encryption(self, responses_id_security): + """Test decrypting a non-encrypted response ID""" + with patch( + "litellm.proxy.hooks.responses_id_security.decrypt_value_helper" + ) as mock_decrypt: + mock_decrypt.return_value = None + + original_id, user_id, team_id = responses_id_security._decrypt_response_id( + "resp_plain_value" + ) + + assert original_id == "resp_plain_value" + assert user_id is None + assert team_id is None + + +class TestEncryptResponseId: + """Test _encrypt_response_id function""" + + def test_encrypt_response_id_success( + self, responses_id_security, mock_user_api_key_dict + ): + """Test encrypting a response ID with user information""" + mock_response = ResponsesAPIResponse( + id="resp_123", created_at=1234567890, output=[], status="completed" + ) + + with patch( + "litellm.proxy.hooks.responses_id_security.encrypt_value_helper" + ) as mock_encrypt: + mock_encrypt.return_value = "encrypted_base64_value" + + result = responses_id_security._encrypt_response_id( + 
mock_response, mock_user_api_key_dict + ) + + assert result.id == "resp_encrypted_base64_value" + assert result.id.startswith("resp_") + mock_encrypt.assert_called_once() + + def test_encrypt_response_id_maintains_prefix( + self, responses_id_security, mock_user_api_key_dict + ): + """Test that encrypted response ID maintains 'resp_' prefix""" + mock_response = ResponsesAPIResponse( + id="resp_456", created_at=1234567890, output=[], status="in_progress" + ) + + with patch( + "litellm.proxy.hooks.responses_id_security.encrypt_value_helper" + ) as mock_encrypt: + mock_encrypt.return_value = "encrypted_value_456" + + result = responses_id_security._encrypt_response_id( + mock_response, mock_user_api_key_dict + ) + + assert result.id.startswith("resp_") + + +class TestCheckUserAccessToResponseId: + """Test check_user_access_to_response_id function""" + + def test_check_user_access_same_user( + self, responses_id_security, mock_user_api_key_dict + ): + """Test that same user has access to their response ID""" + result = responses_id_security.check_user_access_to_response_id( + response_id_user_id="test-user-123", + response_id_team_id="test-team-123", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result is True + + def test_check_user_access_different_user_raises_exception( + self, responses_id_security, mock_user_api_key_dict + ): + """Test that different user is denied access to response ID""" + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + responses_id_security.check_user_access_to_response_id( + response_id_user_id="different-user-456", + response_id_team_id="test-team-123", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert exc_info.value.status_code == 403 + assert "Forbidden" in exc_info.value.detail + + def test_check_user_access_different_team_raises_exception( + self, responses_id_security, mock_user_api_key_dict + ): + """Test that different team is denied access to response ID""" + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + responses_id_security.check_user_access_to_response_id( + response_id_user_id=None, + response_id_team_id="different-team-456", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert exc_info.value.status_code == 403 + assert "Forbidden" in exc_info.value.detail + + def test_check_user_access_team_a_to_team_b_without_user_id( + self, responses_id_security + ): + """Test that key from team A (without user_id) cannot access response from team B (without user_id)""" + # Create a mock user from team A without user_id + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + responses_id_security.check_user_access_to_response_id( + response_id_user_id=None, + response_id_team_id="team-b", + user_api_key_dict=mock_auth_team_a, + ) + + assert exc_info.value.status_code == 403 + assert "team" in exc_info.value.detail.lower() + + def test_check_user_access_team_a_to_team_b_with_user_id( + self, responses_id_security + ): + """Test that key from team A (without user_id) cannot access response from team B (with user_id)""" + # Create a mock user from team A without user_id + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + 
with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + responses_id_security.check_user_access_to_response_id( + response_id_user_id="user-from-team-b", + response_id_team_id="team-b", + user_api_key_dict=mock_auth_team_a, + ) + + # Access should be denied with 403. Could fail on user_id or team_id check. + assert exc_info.value.status_code == 403 + assert "forbidden" in exc_info.value.detail.lower() + + def test_check_user_access_same_team_without_user_id(self, responses_id_security): + """Test that key from team A (without user_id) can access response from same team A (without user_id)""" + # Create a mock user from team A without user_id + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + result = responses_id_security.check_user_access_to_response_id( + response_id_user_id=None, + response_id_team_id="team-a", + user_api_key_dict=mock_auth_team_a, + ) + + assert result is True + + def test_check_user_access_admin_can_access_any_response( + self, responses_id_security + ): + """Test that proxy admin can access any response ID""" + from litellm.proxy._types import LitellmUserRoles + + # Create a mock admin user + mock_admin_auth = MagicMock() + mock_admin_auth.user_id = "admin-user" + mock_admin_auth.team_id = "admin-team" + mock_admin_auth.user_role = LitellmUserRoles.PROXY_ADMIN.value + + # Admin should be able to access response from different team and different user + result = responses_id_security.check_user_access_to_response_id( + response_id_user_id="some-other-user", + response_id_team_id="some-other-team", + user_api_key_dict=mock_admin_auth, + ) + + assert result is True + + def test_check_user_access_security_disabled( + self, responses_id_security, mock_user_api_key_dict + ): + """Test that when security is disabled, any user can access any response""" + with patch( + "litellm.proxy.proxy_server.general_settings", + {"disable_responses_id_security": True}, + ): + # User from team A should be able to access response from team B when security is disabled + result = responses_id_security.check_user_access_to_response_id( + response_id_user_id="different-user", + response_id_team_id="different-team", + user_api_key_dict=mock_user_api_key_dict, + ) + + assert result is True + + +class TestAsyncPreCallHook: + """Test async_pre_call_hook function""" + + @pytest.mark.asyncio + async def test_async_pre_call_hook_aresponses_with_previous_response_id( + self, responses_id_security, mock_user_api_key_dict, mock_cache + ): + """Test pre-call hook decrypts previous_response_id for aresponses call""" + data = {"previous_response_id": "resp_encrypted_value"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_123", "test-user-123", "test-team-123"), + ): + result = await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_user_api_key_dict, + cache=mock_cache, + data=data, + call_type="aresponses", + ) + + assert result["previous_response_id"] == "resp_original_123" + + @pytest.mark.asyncio + async def test_async_pre_call_hook_aget_responses( + self, responses_id_security, mock_user_api_key_dict, mock_cache + ): + """Test pre-call hook decrypts response_id for aget_responses call""" + data = {"response_id": "resp_encrypted_456"} + + with patch.object( + responses_id_security, 
"_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_456", "test-user-123", "test-team-123"), + ): + result = await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_user_api_key_dict, + cache=mock_cache, + data=data, + call_type="aget_responses", + ) + + assert result["response_id"] == "resp_original_456" + + @pytest.mark.asyncio + async def test_async_pre_call_hook_team_a_accessing_team_b_response( + self, responses_id_security, mock_cache + ): + """Test pre-call hook prevents team A from accessing team B response""" + # Create a mock user from team A + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + data = {"response_id": "resp_encrypted_team_b"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_team_b", None, "team-b"), + ): + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_auth_team_a, + cache=mock_cache, + data=data, + call_type="aget_responses", + ) + + assert exc_info.value.status_code == 403 + assert "team" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_async_pre_call_hook_team_a_accessing_team_b_with_user( + self, responses_id_security, mock_cache + ): + """Test pre-call hook prevents team A (no user) from accessing team B response (with user)""" + # Create a mock user from team A without user_id + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + data = {"response_id": "resp_encrypted_team_b_with_user"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_team_b", "user-from-team-b", "team-b"), + ): + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_auth_team_a, + cache=mock_cache, + data=data, + call_type="aget_responses", + ) + + # Access should be denied with 403. Could fail on user_id or team_id check. 
+ assert exc_info.value.status_code == 403 + assert "forbidden" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_async_pre_call_hook_same_team_access( + self, responses_id_security, mock_cache + ): + """Test pre-call hook allows team A to access their own team's response""" + # Create a mock user from team A + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + data = {"response_id": "resp_encrypted_team_a"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_team_a", None, "team-a"), + ): + result = await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_auth_team_a, + cache=mock_cache, + data=data, + call_type="aget_responses", + ) + + assert result["response_id"] == "resp_original_team_a" + + @pytest.mark.asyncio + async def test_async_pre_call_hook_adelete_responses_team_security( + self, responses_id_security, mock_cache + ): + """Test pre-call hook prevents team A from deleting team B's response""" + # Create a mock user from team A + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + data = {"response_id": "resp_encrypted_team_b"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_team_b", None, "team-b"), + ): + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_auth_team_a, + cache=mock_cache, + data=data, + call_type="adelete_responses", + ) + + assert exc_info.value.status_code == 403 + assert "team" in exc_info.value.detail.lower() + + @pytest.mark.asyncio + async def test_async_pre_call_hook_acancel_responses_team_security( + self, responses_id_security, mock_cache + ): + """Test pre-call hook prevents team A from canceling team B's response""" + # Create a mock user from team A + mock_auth_team_a = MagicMock() + mock_auth_team_a.user_id = None + mock_auth_team_a.team_id = "team-a" + mock_auth_team_a.user_role = None + + data = {"response_id": "resp_encrypted_team_b"} + + with patch.object( + responses_id_security, "_is_encrypted_response_id", return_value=True + ): + with patch.object( + responses_id_security, + "_decrypt_response_id", + return_value=("resp_original_team_b", None, "team-b"), + ): + with patch("litellm.proxy.proxy_server.general_settings", {}): + with pytest.raises(HTTPException) as exc_info: + await responses_id_security.async_pre_call_hook( + user_api_key_dict=mock_auth_team_a, + cache=mock_cache, + data=data, + call_type="acancel_responses", + ) + + assert exc_info.value.status_code == 403 + assert "team" in exc_info.value.detail.lower() + + +class TestAsyncPostCallSuccessHook: + """Test async_post_call_success_hook function""" + + @pytest.mark.asyncio + async def test_async_post_call_success_hook_encrypts_response( + self, responses_id_security, mock_user_api_key_dict + ): + """Test post-call hook encrypts ResponsesAPIResponse""" + mock_response = ResponsesAPIResponse( + id="resp_789", created_at=1234567890, output=[], status="completed" + ) + data = {} + + with patch.object( + responses_id_security, 
"_encrypt_response_id", return_value=mock_response + ) as mock_encrypt: + result = await responses_id_security.async_post_call_success_hook( + data=data, + user_api_key_dict=mock_user_api_key_dict, + response=mock_response, + ) + + mock_encrypt.assert_called_once_with(mock_response, mock_user_api_key_dict) + assert result == mock_response + + @pytest.mark.asyncio + async def test_async_post_call_success_hook_non_responses_api_response( + self, responses_id_security, mock_user_api_key_dict + ): + """Test post-call hook passes through non-ResponsesAPIResponse objects""" + mock_response = {"id": "some-other-response", "data": "test"} + data = {} + + result = await responses_id_security.async_post_call_success_hook( + data=data, + user_api_key_dict=mock_user_api_key_dict, + response=mock_response, + ) + + assert result == mock_response