4 | 4 | # This source code is licensed under the terms described in the LICENSE file in |
5 | 5 | # the root directory of this source tree. |
6 | 6 |
7 | | -import json |
8 | | -from collections.abc import AsyncIterator |
| 7 | +from collections.abc import AsyncIterator, Iterable |
9 | 8 |
10 | | -from botocore.client import BaseClient |
| 9 | +from openai import AuthenticationError |
11 | 10 |
12 | 11 | from llama_stack.apis.inference import ( |
13 | | - ChatCompletionRequest, |
14 | | - Inference, |
| 12 | + Model, |
| 13 | + OpenAIChatCompletion, |
| 14 | + OpenAIChatCompletionChunk, |
15 | 15 | OpenAIChatCompletionRequestWithExtraBody, |
| 16 | + OpenAICompletion, |
16 | 17 | OpenAICompletionRequestWithExtraBody, |
17 | 18 | OpenAIEmbeddingsRequestWithExtraBody, |
18 | 19 | OpenAIEmbeddingsResponse, |
19 | 20 | ) |
20 | | -from llama_stack.apis.inference.inference import ( |
21 | | - OpenAIChatCompletion, |
22 | | - OpenAIChatCompletionChunk, |
23 | | - OpenAICompletion, |
24 | | -) |
25 | | -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig |
26 | | -from llama_stack.providers.utils.bedrock.client import create_bedrock_client |
27 | | -from llama_stack.providers.utils.inference.model_registry import ( |
28 | | - ModelRegistryHelper, |
29 | | -) |
30 | | -from llama_stack.providers.utils.inference.openai_compat import ( |
31 | | - get_sampling_strategy_options, |
32 | | -) |
33 | | -from llama_stack.providers.utils.inference.prompt_adapter import ( |
34 | | - chat_completion_request_to_prompt, |
35 | | -) |
36 | | - |
37 | | -from .models import MODEL_ENTRIES |
38 | | - |
39 | | -REGION_PREFIX_MAP = { |
40 | | - "us": "us.", |
41 | | - "eu": "eu.", |
42 | | - "ap": "ap.", |
43 | | -} |
44 | | - |
45 | | - |
46 | | -def _get_region_prefix(region: str | None) -> str: |
47 | | - # AWS requires region prefixes for inference profiles |
48 | | - if region is None: |
49 | | - return "us." # default to US when we don't know |
50 | | - |
51 | | - # Handle case insensitive region matching |
52 | | - region_lower = region.lower() |
53 | | - for prefix in REGION_PREFIX_MAP: |
54 | | - if region_lower.startswith(f"{prefix}-"): |
55 | | - return REGION_PREFIX_MAP[prefix] |
56 | | - |
57 | | - # Fallback to US for anything we don't recognize |
58 | | - return "us." |
59 | | - |
60 | | - |
61 | | -def _to_inference_profile_id(model_id: str, region: str = None) -> str: |
62 | | - # Return ARNs unchanged |
63 | | - if model_id.startswith("arn:"): |
64 | | - return model_id |
65 | | - |
66 | | - # Return inference profile IDs that already have regional prefixes |
67 | | - if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()): |
68 | | - return model_id |
69 | | - |
70 | | - # Default to US East when no region is provided |
71 | | - if region is None: |
72 | | - region = "us-east-1" |
73 | | - |
74 | | - return _get_region_prefix(region) + model_id |
| 21 | +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin |
| 22 | +from llama_stack.providers.utils.telemetry.tracing import get_current_span |
75 | 23 |
| 24 | +from .config import BedrockConfig |
76 | 25 |
77 | | -class BedrockInferenceAdapter( |
78 | | - ModelRegistryHelper, |
79 | | - Inference, |
80 | | -): |
81 | | - def __init__(self, config: BedrockConfig) -> None: |
82 | | - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) |
83 | | - self._config = config |
84 | | - self._client = None |
85 | 26 |
86 | | - @property |
87 | | - def client(self) -> BaseClient: |
88 | | - if self._client is None: |
89 | | - self._client = create_bedrock_client(self._config) |
90 | | - return self._client |
| 27 | +class BedrockInferenceAdapter(OpenAIMixin): |
| 28 | + """ |
| 29 | + Adapter for AWS Bedrock's OpenAI-compatible API endpoints. |
91 | 30 |
92 | | - async def initialize(self) -> None: |
93 | | - pass |
| 31 | + Supports Llama models across regions and GPT-OSS models (us-west-2 only). |
94 | 32 |
95 | | - async def shutdown(self) -> None: |
96 | | - if self._client is not None: |
97 | | - self._client.close() |
| 33 | + Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models |
| 34 | + for dynamic model discovery. Models must be pre-registered in the config. |
| 35 | + """ |
98 | 36 |
99 | | - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: |
100 | | - bedrock_model = request.model |
| 37 | + config: BedrockConfig |
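 | | +    # Field name OpenAIMixin looks up in the x-llamastack-provider-data
 | | +    # header, letting callers supply a per-request Bedrock API key.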
| 38 | + provider_data_api_key_field: str = "aws_bedrock_api_key" |
101 | 39 |
102 | | - sampling_params = request.sampling_params |
103 | | - options = get_sampling_strategy_options(sampling_params) |
| 40 | + def get_api_key(self) -> str: |
| 41 | + """Get API key for OpenAI client.""" |
| 42 | + if not self.config.api_key: |
| 43 | + raise ValueError( |
| 44 | + "API key is not set. Please provide a valid API key in the " |
 | 45 | +                "provider config or via the AWS_BEDROCK_API_KEY environment variable."
| 46 | + ) |
| 47 | + return self.config.api_key |
104 | 48 |
105 | | - if sampling_params.max_tokens: |
106 | | - options["max_gen_len"] = sampling_params.max_tokens |
107 | | - if sampling_params.repetition_penalty > 0: |
108 | | - options["repetition_penalty"] = sampling_params.repetition_penalty |
| 49 | + def get_base_url(self) -> str: |
| 50 | + """Get base URL for OpenAI client.""" |
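 | | +        # e.g. region_name "us-west-2" -> https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1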
| 51 | + return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1" |
109 | 52 |
110 | | - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) |
| 53 | + async def list_provider_model_ids(self) -> Iterable[str]: |
| 54 | + """ |
| 55 | + Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint. |
| 56 | + Returns empty list since models must be pre-registered in the config. |
| 57 | + """ |
| 58 | + return [] |
111 | 59 |
112 | | - # Convert foundation model ID to inference profile ID |
113 | | - region_name = self.client.meta.region_name |
114 | | - inference_profile_id = _to_inference_profile_id(bedrock_model, region_name) |
| 60 | + async def register_model(self, model: Model) -> Model: |
| 61 | + """ |
| 62 | + Register a model with the Bedrock provider. |
115 | 63 |
116 | | - return { |
117 | | - "modelId": inference_profile_id, |
118 | | - "body": json.dumps( |
119 | | - { |
120 | | - "prompt": prompt, |
121 | | - **options, |
122 | | - } |
123 | | - ), |
124 | | - } |
| 64 | + Bedrock doesn't support dynamic model listing via /v1/models, so we skip |
| 65 | + the availability check and accept all models registered in the config. |
| 66 | + """ |
| 67 | + return model |
125 | 68 |
126 | 69 | async def openai_embeddings( |
127 | 70 | self, |
128 | 71 | params: OpenAIEmbeddingsRequestWithExtraBody, |
129 | 72 | ) -> OpenAIEmbeddingsResponse: |
130 | | - raise NotImplementedError() |
| 73 | + """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint.""" |
| 74 | + raise NotImplementedError( |
 | 75 | +            "Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint. "
| 76 | + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" |
| 77 | + ) |
131 | 78 |
132 | 79 | async def openai_completion( |
133 | 80 | self, |
134 | 81 | params: OpenAICompletionRequestWithExtraBody, |
135 | 82 | ) -> OpenAICompletion: |
136 | | - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") |
| 83 | + """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint.""" |
| 84 | + raise NotImplementedError( |
 | 85 | +            "Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint. "
| 86 | + "Only /v1/chat/completions is supported. " |
| 87 | + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" |
| 88 | + ) |
137 | 89 |
138 | 90 | async def openai_chat_completion( |
139 | 91 | self, |
140 | 92 | params: OpenAIChatCompletionRequestWithExtraBody, |
141 | 93 | ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: |
142 | | - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") |
| 94 | + """Override to enable streaming usage metrics and handle authentication errors.""" |
| 95 | + # Enable streaming usage metrics when telemetry is active |
| 96 | + if params.stream and get_current_span() is not None: |
| 97 | + if params.stream_options is None: |
| 98 | + params.stream_options = {"include_usage": True} |
| 99 | + elif "include_usage" not in params.stream_options: |
| 100 | + params.stream_options = {**params.stream_options, "include_usage": True} |
| 101 | + |
| 102 | + # Wrap call in try/except to catch authentication errors |
| 103 | + try: |
| 104 | + return await super().openai_chat_completion(params=params) |
| 105 | + except AuthenticationError as e: |
| 106 | + raise ValueError( |
| 107 | + f"AWS Bedrock authentication failed: {e.message}. " |
| 108 | + "Please check your API key in the provider config or x-llamastack-provider-data header." |
| 109 | + ) from e |
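
For reference, a minimal usage sketch of the rewritten adapter. The adapter's
module path, the keyword constructor, and the model id are illustrative
assumptions; the `api_key`/`region_name` config fields and the request type are
taken from the diff above.

import asyncio

from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

# Assumed module path for the adapter class shown in the diff.
from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter


async def main() -> None:
    # BedrockConfig fields match those referenced in get_api_key()/get_base_url().
    config = BedrockConfig(api_key="...", region_name="us-west-2")
    adapter = BedrockInferenceAdapter(config=config)  # assumed keyword constructor

    params = OpenAIChatCompletionRequestWithExtraBody(
        model="meta.llama3-1-8b-instruct-v1:0",  # hypothetical pre-registered model id
        messages=[{"role": "user", "content": "Hello from Bedrock"}],
    )
    # A rejected key surfaces as ValueError("AWS Bedrock authentication failed: ...").
    response = await adapter.openai_chat_completion(params=params)
    print(response)


asyncio.run(main())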