Skip to content

Commit 56bff11

Browse files
committed
feat: add OpenAI-compatible Bedrock provider with error handling
Implements AWS Bedrock inference provider using OpenAI-compatible endpoint for Llama models available through Bedrock. Changes: - Add BedrockInferenceAdapter using OpenAIMixin base - Configure region-specific endpoint URLs - Add NotImplementedError stubs for unsupported endpoints - Implement authentication error handling with helpful messages - Remove unused models.py file - Add comprehensive unit tests (12 total) - Add provider registry configuration
1 parent 96886af commit 56bff11

File tree

12 files changed

+300
-186
lines changed

12 files changed

+300
-186
lines changed
Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
---
2-
description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
2+
description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
33
sidebar_label: Remote - Bedrock
44
title: remote::bedrock
55
---
@@ -8,27 +8,20 @@ title: remote::bedrock
88

99
## Description
1010

11-
AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
11+
AWS Bedrock inference provider using OpenAI compatible endpoint.
1212

1313
## Configuration
1414

1515
| Field | Type | Required | Default | Description |
1616
|-------|------|----------|---------|-------------|
1717
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
1818
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
19-
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
20-
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
21-
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
22-
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
23-
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
24-
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
25-
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
26-
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
27-
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
28-
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
19+
| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
20+
| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
2921

3022
## Sample Configuration
3123

3224
```yaml
33-
{}
25+
api_key: ${env.AWS_BEDROCK_API_KEY:=}
26+
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
3427
```

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ providers:
4747
api_key: ${env.TOGETHER_API_KEY:=}
4848
- provider_id: bedrock
4949
provider_type: remote::bedrock
50+
config:
51+
api_key: ${env.AWS_BEDROCK_API_KEY:=}
52+
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
5053
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
5154
provider_type: remote::nvidia
5255
config:

llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ providers:
4747
api_key: ${env.TOGETHER_API_KEY:=}
4848
- provider_id: bedrock
4949
provider_type: remote::bedrock
50+
config:
51+
api_key: ${env.AWS_BEDROCK_API_KEY:=}
52+
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
5053
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
5154
provider_type: remote::nvidia
5255
config:

llama_stack/distributions/starter/run.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ providers:
4747
api_key: ${env.TOGETHER_API_KEY:=}
4848
- provider_id: bedrock
4949
provider_type: remote::bedrock
50+
config:
51+
api_key: ${env.AWS_BEDROCK_API_KEY:=}
52+
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
5053
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
5154
provider_type: remote::nvidia
5255
config:

llama_stack/providers/registry/inference.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
131131
api=Api.inference,
132132
adapter_type="bedrock",
133133
provider_type="remote::bedrock",
134-
pip_packages=["boto3"],
134+
pip_packages=[],
135135
module="llama_stack.providers.remote.inference.bedrock",
136136
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
137-
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
137+
provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
138+
description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
138139
),
139140
RemoteProviderSpec(
140141
api=Api.inference,

llama_stack/providers/remote/inference/bedrock/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
1111

1212
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
1313

14-
impl = BedrockInferenceAdapter(config)
14+
impl = BedrockInferenceAdapter(config=config)
1515

1616
await impl.initialize()
1717

llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 91 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -4,124 +4,67 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
import json
8-
from collections.abc import AsyncIterator
7+
from collections.abc import AsyncIterator, Iterable
98
from typing import Any
109

11-
from botocore.client import BaseClient
10+
from openai import AuthenticationError
1211

1312
from llama_stack.apis.inference import (
14-
ChatCompletionRequest,
15-
Inference,
16-
OpenAIEmbeddingsResponse,
17-
)
18-
from llama_stack.apis.inference.inference import (
13+
Model,
1914
OpenAIChatCompletion,
2015
OpenAIChatCompletionChunk,
2116
OpenAICompletion,
17+
OpenAIEmbeddingsResponse,
2218
OpenAIMessageParam,
2319
OpenAIResponseFormatParam,
2420
)
25-
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
26-
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
27-
from llama_stack.providers.utils.inference.model_registry import (
28-
ModelRegistryHelper,
29-
)
30-
from llama_stack.providers.utils.inference.openai_compat import (
31-
get_sampling_strategy_options,
32-
)
33-
from llama_stack.providers.utils.inference.prompt_adapter import (
34-
chat_completion_request_to_prompt,
35-
)
36-
37-
from .models import MODEL_ENTRIES
38-
39-
REGION_PREFIX_MAP = {
40-
"us": "us.",
41-
"eu": "eu.",
42-
"ap": "ap.",
43-
}
44-
45-
46-
def _get_region_prefix(region: str | None) -> str:
47-
# AWS requires region prefixes for inference profiles
48-
if region is None:
49-
return "us." # default to US when we don't know
50-
51-
# Handle case insensitive region matching
52-
region_lower = region.lower()
53-
for prefix in REGION_PREFIX_MAP:
54-
if region_lower.startswith(f"{prefix}-"):
55-
return REGION_PREFIX_MAP[prefix]
56-
57-
# Fallback to US for anything we don't recognize
58-
return "us."
59-
60-
61-
def _to_inference_profile_id(model_id: str, region: str = None) -> str:
62-
# Return ARNs unchanged
63-
if model_id.startswith("arn:"):
64-
return model_id
65-
66-
# Return inference profile IDs that already have regional prefixes
67-
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
68-
return model_id
69-
70-
# Default to US East when no region is provided
71-
if region is None:
72-
region = "us-east-1"
73-
74-
return _get_region_prefix(region) + model_id
21+
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
22+
from llama_stack.providers.utils.telemetry.tracing import get_current_span
7523

24+
from .config import BedrockConfig
7625

77-
class BedrockInferenceAdapter(
78-
ModelRegistryHelper,
79-
Inference,
80-
):
81-
def __init__(self, config: BedrockConfig) -> None:
82-
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
83-
self._config = config
84-
self._client = None
8526

86-
@property
87-
def client(self) -> BaseClient:
88-
if self._client is None:
89-
self._client = create_bedrock_client(self._config)
90-
return self._client
27+
class BedrockInferenceAdapter(OpenAIMixin):
28+
"""
29+
Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
9130
92-
async def initialize(self) -> None:
93-
pass
31+
Supports Llama models across regions and GPT-OSS models (us-west-2 only).
9432
95-
async def shutdown(self) -> None:
96-
if self._client is not None:
97-
self._client.close()
33+
Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
34+
for dynamic model discovery. Models must be pre-registered in the config.
35+
"""
9836

99-
async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
100-
bedrock_model = request.model
37+
config: BedrockConfig
38+
provider_data_api_key_field: str = "aws_bedrock_api_key"
10139

102-
sampling_params = request.sampling_params
103-
options = get_sampling_strategy_options(sampling_params)
40+
def get_api_key(self) -> str:
41+
"""Get API key for OpenAI client."""
42+
if not self.config.api_key:
43+
raise ValueError(
44+
"API key is not set. Please provide a valid API key in the "
45+
"provider config or via AWS_BEDROCK_API_KEY environment variable."
46+
)
47+
return self.config.api_key
10448

105-
if sampling_params.max_tokens:
106-
options["max_gen_len"] = sampling_params.max_tokens
107-
if sampling_params.repetition_penalty > 0:
108-
options["repetition_penalty"] = sampling_params.repetition_penalty
49+
def get_base_url(self) -> str:
50+
"""Get base URL for OpenAI client."""
51+
return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
10952

110-
prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
53+
async def list_provider_model_ids(self) -> Iterable[str]:
54+
"""
55+
Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
56+
Returns empty list since models must be pre-registered in the config.
57+
"""
58+
return []
11159

112-
# Convert foundation model ID to inference profile ID
113-
region_name = self.client.meta.region_name
114-
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
60+
async def register_model(self, model: Model) -> Model:
61+
"""
62+
Register a model with the Bedrock provider.
11563
116-
return {
117-
"modelId": inference_profile_id,
118-
"body": json.dumps(
119-
{
120-
"prompt": prompt,
121-
**options,
122-
}
123-
),
124-
}
64+
Bedrock doesn't support dynamic model listing via /v1/models, so we skip
65+
the availability check and accept all models registered in the config.
66+
"""
67+
return model
12568

12669
async def openai_embeddings(
12770
self,
@@ -131,11 +74,14 @@ async def openai_embeddings(
13174
dimensions: int | None = None,
13275
user: str | None = None,
13376
) -> OpenAIEmbeddingsResponse:
134-
raise NotImplementedError()
77+
"""Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
78+
raise NotImplementedError(
79+
"Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
80+
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
81+
)
13582

13683
async def openai_completion(
13784
self,
138-
# Standard OpenAI completion parameters
13985
model: str,
14086
prompt: str | list[str] | list[int] | list[list[int]],
14187
best_of: int | None = None,
@@ -153,13 +99,16 @@ async def openai_completion(
15399
temperature: float | None = None,
154100
top_p: float | None = None,
155101
user: str | None = None,
156-
# vLLM-specific parameters
157102
guided_choice: list[str] | None = None,
158103
prompt_logprobs: int | None = None,
159-
# for fill-in-the-middle type completion
160104
suffix: str | None = None,
161105
) -> OpenAICompletion:
162-
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
106+
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
107+
raise NotImplementedError(
108+
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
109+
"Only /v1/chat/completions is supported. "
110+
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
111+
)
163112

164113
async def openai_chat_completion(
165114
self,
@@ -187,4 +136,43 @@ async def openai_chat_completion(
187136
top_p: float | None = None,
188137
user: str | None = None,
189138
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
190-
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
139+
"""Override to enable streaming usage metrics and handle authentication errors."""
140+
# Enable streaming usage metrics when telemetry is active
141+
if stream and get_current_span() is not None:
142+
if stream_options is None:
143+
stream_options = {"include_usage": True}
144+
elif "include_usage" not in stream_options:
145+
stream_options = {**stream_options, "include_usage": True}
146+
147+
# Wrap call in try/except to catch authentication errors
148+
try:
149+
return await super().openai_chat_completion(
150+
model=model,
151+
messages=messages,
152+
frequency_penalty=frequency_penalty,
153+
function_call=function_call,
154+
functions=functions,
155+
logit_bias=logit_bias,
156+
logprobs=logprobs,
157+
max_completion_tokens=max_completion_tokens,
158+
max_tokens=max_tokens,
159+
n=n,
160+
parallel_tool_calls=parallel_tool_calls,
161+
presence_penalty=presence_penalty,
162+
response_format=response_format,
163+
seed=seed,
164+
stop=stop,
165+
stream=stream,
166+
stream_options=stream_options,
167+
temperature=temperature,
168+
tool_choice=tool_choice,
169+
tools=tools,
170+
top_logprobs=top_logprobs,
171+
top_p=top_p,
172+
user=user,
173+
)
174+
except AuthenticationError as e:
175+
raise ValueError(
176+
f"AWS Bedrock authentication failed: {e.message}. "
177+
"Please check your API key in the provider config or x-llamastack-provider-data header."
178+
) from e

llama_stack/providers/remote/inference/bedrock/config.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,33 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
7+
import os
88

9+
from pydantic import BaseModel, Field
910

10-
class BedrockConfig(BedrockBaseConfig):
11-
pass
11+
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
12+
13+
14+
class BedrockProviderDataValidator(BaseModel):
15+
aws_bedrock_api_key: str | None = Field(
16+
default=None,
17+
description="API key for Amazon Bedrock",
18+
)
19+
20+
21+
class BedrockConfig(RemoteInferenceProviderConfig):
22+
api_key: str | None = Field(
23+
default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
24+
description="Amazon Bedrock API key",
25+
)
26+
region_name: str = Field(
27+
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
28+
description="AWS Region for the Bedrock Runtime endpoint",
29+
)
30+
31+
@classmethod
32+
def sample_run_config(cls, **kwargs):
33+
return {
34+
"api_key": "${env.AWS_BEDROCK_API_KEY:=}",
35+
"region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
36+
}

0 commit comments

Comments
 (0)