4 | 4 | # This source code is licensed under the terms described in the LICENSE file in |
5 | 5 | # the root directory of this source tree. |
6 | 6 |
7 | | -import json |
8 | | -from collections.abc import AsyncIterator |
| 7 | +from collections.abc import AsyncIterator, Iterable |
9 | 8 | from typing import Any |
10 | 9 |
11 | | -from botocore.client import BaseClient |
| 10 | +from openai import AuthenticationError, BadRequestError, NotFoundError |
12 | 11 |
13 | 12 | from llama_stack.apis.inference import ( |
14 | | - ChatCompletionRequest, |
15 | | - Inference, |
16 | | - OpenAIEmbeddingsResponse, |
17 | | -) |
18 | | -from llama_stack.apis.inference.inference import ( |
| 13 | + Model, |
19 | 14 | OpenAIChatCompletion, |
20 | 15 | OpenAIChatCompletionChunk, |
21 | | - OpenAICompletion, |
22 | 16 | OpenAIMessageParam, |
23 | 17 | OpenAIResponseFormatParam, |
24 | 18 | ) |
25 | | -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig |
26 | | -from llama_stack.providers.utils.bedrock.client import create_bedrock_client |
27 | | -from llama_stack.providers.utils.inference.model_registry import ( |
28 | | - ModelRegistryHelper, |
29 | | -) |
30 | | -from llama_stack.providers.utils.inference.openai_compat import ( |
31 | | - get_sampling_strategy_options, |
32 | | -) |
33 | | -from llama_stack.providers.utils.inference.prompt_adapter import ( |
34 | | - chat_completion_request_to_prompt, |
35 | | -) |
36 | | - |
37 | | -from .models import MODEL_ENTRIES |
38 | | - |
39 | | -REGION_PREFIX_MAP = { |
40 | | - "us": "us.", |
41 | | - "eu": "eu.", |
42 | | - "ap": "ap.", |
43 | | -} |
44 | | - |
45 | | - |
46 | | -def _get_region_prefix(region: str | None) -> str: |
47 | | - # AWS requires region prefixes for inference profiles |
48 | | - if region is None: |
49 | | - return "us." # default to US when we don't know |
50 | | - |
51 | | - # Handle case insensitive region matching |
52 | | - region_lower = region.lower() |
53 | | - for prefix in REGION_PREFIX_MAP: |
54 | | - if region_lower.startswith(f"{prefix}-"): |
55 | | - return REGION_PREFIX_MAP[prefix] |
| 19 | +from llama_stack.log import get_logger |
| 20 | +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params |
| 21 | +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin |
| 22 | +from llama_stack.providers.utils.telemetry.tracing import get_current_span |
56 | 23 |
57 | | - # Fallback to US for anything we don't recognize |
58 | | - return "us." |
| 24 | +from .config import BedrockConfig |
59 | 25 |
| 26 | +logger = get_logger(name=__name__, category="inference::bedrock") |
60 | 27 |
61 | | -def _to_inference_profile_id(model_id: str, region: str = None) -> str: |
62 | | - # Return ARNs unchanged |
63 | | - if model_id.startswith("arn:"): |
64 | | - return model_id |
65 | 28 |
66 | | - # Return inference profile IDs that already have regional prefixes |
67 | | - if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()): |
68 | | - return model_id |
| 29 | +class BedrockInferenceAdapter(OpenAIMixin): |
| 30 | + """ |
| 31 | + Adapter for AWS Bedrock's OpenAI-compatible API endpoints. |
69 | 32 |
70 | | - # Default to US East when no region is provided |
71 | | - if region is None: |
72 | | - region = "us-east-1" |
| 33 | + Supports Llama models across regions and GPT-OSS models (us-west-2 only). |
73 | 34 |
74 | | - return _get_region_prefix(region) + model_id |
| 35 | + Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models |
| 36 | + for dynamic model discovery. Models must be pre-registered in the config. |
| 37 | + """ |
75 | 38 |
| 39 | + config: BedrockConfig |
| 40 | + provider_data_api_key_field: str = "aws_bedrock_api_key" |
76 | 41 |
77 | | -class BedrockInferenceAdapter( |
78 | | - ModelRegistryHelper, |
79 | | - Inference, |
80 | | -): |
81 | | - def __init__(self, config: BedrockConfig) -> None: |
82 | | - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) |
83 | | - self._config = config |
84 | | - self._client = None |
| 42 | + def get_api_key(self) -> str: |
| 43 | + """Get API key for OpenAI client.""" |
| 44 | + if not self.config.api_key: |
| 45 | + raise ValueError( |
| 46 | + "API key is not set. Please provide a valid API key in the " |
| 47 | + "provider config or via AWS_BEDROCK_API_KEY environment variable." |
| 48 | + ) |
| 49 | + return self.config.api_key |
85 | 50 |
86 | | - @property |
87 | | - def client(self) -> BaseClient: |
88 | | - if self._client is None: |
89 | | - self._client = create_bedrock_client(self._config) |
90 | | - return self._client |
| 51 | + def get_base_url(self) -> str: |
| 52 | + """Get base URL for OpenAI client.""" |
| 53 | + return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1" |
91 | 54 |
92 | | - async def initialize(self) -> None: |
93 | | - pass |
| 55 | + async def list_provider_model_ids(self) -> Iterable[str]: |
| 56 | + """ |
| 57 | + Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint. |
| 58 | +        Returns an empty list because models must be pre-registered in the config.
| 59 | + """ |
| 60 | + return [] |
94 | 61 |
95 | | - async def shutdown(self) -> None: |
96 | | - if self._client is not None: |
97 | | - self._client.close() |
98 | | - |
99 | | - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: |
100 | | - bedrock_model = request.model |
101 | | - |
102 | | - sampling_params = request.sampling_params |
103 | | - options = get_sampling_strategy_options(sampling_params) |
104 | | - |
105 | | - if sampling_params.max_tokens: |
106 | | - options["max_gen_len"] = sampling_params.max_tokens |
107 | | - if sampling_params.repetition_penalty > 0: |
108 | | - options["repetition_penalty"] = sampling_params.repetition_penalty |
109 | | - |
110 | | - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) |
111 | | - |
112 | | - # Convert foundation model ID to inference profile ID |
113 | | - region_name = self.client.meta.region_name |
114 | | - inference_profile_id = _to_inference_profile_id(bedrock_model, region_name) |
115 | | - |
116 | | - return { |
117 | | - "modelId": inference_profile_id, |
118 | | - "body": json.dumps( |
119 | | - { |
120 | | - "prompt": prompt, |
121 | | - **options, |
122 | | - } |
123 | | - ), |
124 | | - } |
125 | | - |
126 | | - async def openai_embeddings( |
127 | | - self, |
128 | | - model: str, |
129 | | - input: str | list[str], |
130 | | - encoding_format: str | None = "float", |
131 | | - dimensions: int | None = None, |
132 | | - user: str | None = None, |
133 | | - ) -> OpenAIEmbeddingsResponse: |
134 | | - raise NotImplementedError() |
| 62 | + async def register_model(self, model: Model) -> Model: |
| 63 | + """ |
| 64 | + Register a model with the Bedrock provider. |
135 | 65 |
136 | | - async def openai_completion( |
137 | | - self, |
138 | | - # Standard OpenAI completion parameters |
139 | | - model: str, |
140 | | - prompt: str | list[str] | list[int] | list[list[int]], |
141 | | - best_of: int | None = None, |
142 | | - echo: bool | None = None, |
143 | | - frequency_penalty: float | None = None, |
144 | | - logit_bias: dict[str, float] | None = None, |
145 | | - logprobs: bool | None = None, |
146 | | - max_tokens: int | None = None, |
147 | | - n: int | None = None, |
148 | | - presence_penalty: float | None = None, |
149 | | - seed: int | None = None, |
150 | | - stop: str | list[str] | None = None, |
151 | | - stream: bool | None = None, |
152 | | - stream_options: dict[str, Any] | None = None, |
153 | | - temperature: float | None = None, |
154 | | - top_p: float | None = None, |
155 | | - user: str | None = None, |
156 | | - # vLLM-specific parameters |
157 | | - guided_choice: list[str] | None = None, |
158 | | - prompt_logprobs: int | None = None, |
159 | | - # for fill-in-the-middle type completion |
160 | | - suffix: str | None = None, |
161 | | - ) -> OpenAICompletion: |
162 | | - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") |
| 66 | + Bedrock doesn't support dynamic model listing via /v1/models, so we skip |
| 67 | + the availability check and accept all models registered in the config. |
| 68 | + """ |
| 69 | + return model |
163 | 70 |
164 | 71 | async def openai_chat_completion( |
165 | 72 | self, |
@@ -187,4 +94,67 @@ async def openai_chat_completion( |
187 | 94 | top_p: float | None = None, |
188 | 95 | user: str | None = None, |
189 | 96 | ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: |
190 | | - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") |
| 97 | + """Override to add Bedrock-specific model ID handling and streaming options.""" |
| 98 | + # Get the provider model ID from model store |
| 99 | + model_obj = await self.model_store.get_model(model) # type: ignore[attr-defined] |
| 100 | + provider_model_id: str = model_obj.provider_resource_id or model |
| 101 | + |
| 102 | + # Bedrock OpenAI-compatible endpoint expects base model IDs (e.g. "openai.gpt-oss-20b-1:0"). |
| 103 | + # Cross-region inference profile IDs prefixed with "us." are not recognized by the endpoint. |
| 104 | + # Normalize to base model ID, then try both base and prefixed forms for compatibility. |
| 105 | + base_model_id = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id |
| 106 | + candidate_models = [base_model_id, f"us.{base_model_id}"] |
| 107 | + |
| 108 | + # Enable streaming usage metrics when telemetry is active |
| 109 | + if stream and get_current_span() is not None: |
| 110 | + if stream_options is None: |
| 111 | + stream_options = {"include_usage": True} |
| 112 | + elif "include_usage" not in stream_options: |
| 113 | + stream_options = {**stream_options, "include_usage": True} |
| 114 | + |
| 115 | + # Try candidate model IDs with retry logic |
| 116 | + last_error: Exception | None = None |
| 117 | + for candidate in candidate_models: |
| 118 | + try: |
| 119 | + logger.debug(f"Attempting request with model ID: {candidate}") |
| 120 | + # Call OpenAI client directly with the candidate model ID |
| 121 | + # We can't use super().openai_chat_completion() because it would |
| 122 | +                # call self._get_provider_model_id(), which looks up the model again
| 123 | + params = await prepare_openai_completion_params( |
| 124 | + model=candidate, |
| 125 | + messages=messages, |
| 126 | + frequency_penalty=frequency_penalty, |
| 127 | + function_call=function_call, |
| 128 | + functions=functions, |
| 129 | + logit_bias=logit_bias, |
| 130 | + logprobs=logprobs, |
| 131 | + max_completion_tokens=max_completion_tokens, |
| 132 | + max_tokens=max_tokens, |
| 133 | + n=n, |
| 134 | + parallel_tool_calls=parallel_tool_calls, |
| 135 | + presence_penalty=presence_penalty, |
| 136 | + response_format=response_format, |
| 137 | + seed=seed, |
| 138 | + stop=stop, |
| 139 | + stream=stream, |
| 140 | + stream_options=stream_options, |
| 141 | + temperature=temperature, |
| 142 | + tool_choice=tool_choice, |
| 143 | + tools=tools, |
| 144 | + top_logprobs=top_logprobs, |
| 145 | + top_p=top_p, |
| 146 | + user=user, |
| 147 | + ) |
| 148 | + resp = await self.client.chat.completions.create(**params) |
| 149 | + return await self._maybe_overwrite_id(resp, stream) # type: ignore[no-any-return] |
| 150 | + except AuthenticationError as e: |
| 151 | + # Authentication errors - no retry with different model IDs |
| 152 | + raise ValueError(f"Authentication failed: {str(e)}") from e |
| 153 | + except (NotFoundError, BadRequestError) as e: |
| 154 | + logger.debug(f"Model ID {candidate} failed: {e}, trying next candidate") |
| 155 | + last_error = e |
| 156 | + continue |
| 157 | + |
| 158 | + if last_error: |
| 159 | + raise last_error |
| 160 | + raise RuntimeError("Bedrock chat completion failed") |
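
For context, the refactored adapter points an OpenAI client at Bedrock's OpenAI-compatible runtime endpoint. Below is a minimal sketch of the equivalent direct call with the `openai` SDK; the region, credential environment variable, and model ID are illustrative assumptions drawn from the comments in this diff, not values the adapter hard-codes.

```python
# Minimal sketch (not part of this change): calling Bedrock's OpenAI-compatible
# endpoint directly with the openai SDK, using the same base URL shape that
# BedrockInferenceAdapter.get_base_url() builds. Region, credential env var,
# and model ID are illustrative assumptions.
import os

from openai import OpenAI

region = "us-west-2"  # per the class docstring, GPT-OSS models are served here only
client = OpenAI(
    base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)

response = client.chat.completions.create(
    model="openai.gpt-oss-20b-1:0",  # base model ID, without the "us." inference-profile prefix
    messages=[{"role": "user", "content": "Say hello from Bedrock."}],
)
print(response.choices[0].message.content)
```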
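
The model ID fallback in `openai_chat_completion` reduces to one normalization step plus a two-element candidate list, with the base ID tried first because the endpoint does not recognize the cross-region `us.` form. A self-contained sketch of that logic, using an illustrative model ID:

```python
# Sketch of the candidate-ID fallback used in openai_chat_completion: strip a
# cross-region "us." inference-profile prefix, then try the base ID first and
# the prefixed ID second.
def candidate_model_ids(provider_model_id: str) -> list[str]:
    base = provider_model_id.removeprefix("us.")
    return [base, f"us.{base}"]


# Example with an illustrative Bedrock model ID:
assert candidate_model_ids("us.meta.llama3-1-8b-instruct-v1:0") == [
    "meta.llama3-1-8b-instruct-v1:0",
    "us.meta.llama3-1-8b-instruct-v1:0",
]
```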