
Commit 7024e56

feat: add OpenAI-compatible Bedrock provider with error handling
Implements AWS Bedrock inference provider using the OpenAI-compatible endpoint for Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter using OpenAIMixin base
- Configure region-specific endpoint URLs
- Add NotImplementedError stubs for unsupported endpoints
- Implement authentication error handling with helpful messages
- Remove unused models.py file
- Add comprehensive unit tests (12 total)
- Add provider registry configuration
1 parent 968c364 commit 7024e56

File tree

12 files changed: +282 -185 lines changed
Lines changed: 6 additions & 13 deletions

@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,27 +8,20 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI compatible endpoint.
 
 ## Configuration
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform. Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
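For context, a minimal sketch (not part of this commit) of calling the Bedrock OpenAI-compatible endpoint this provider targets, using the official `openai` client. The model ID is illustrative and must be one enabled in your AWS account; `AWS_BEDROCK_API_KEY` is assumed to be set:

```python
# Minimal sketch: talk to Bedrock's OpenAI-compatible endpoint directly.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)

response = client.chat.completions.create(
    model="meta.llama3-1-8b-instruct-v1:0",  # illustrative Bedrock model ID
    messages=[{"role": "user", "content": "Say hello from Bedrock."}],
)
print(response.choices[0].message.content)
```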

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
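The `${env.VAR:=default}` references above are resolved at stack startup: the environment variable wins, otherwise the default applies. A rough sketch of that substitution rule (illustrative only; the real llama-stack resolver handles more syntax, e.g. the `:+` conditional form used for the nvidia provider):

```python
# Illustrative re-implementation of ${env.NAME:=default} substitution.
import os
import re

_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}")

def resolve(value: str) -> str:
    # Replace each ${env.NAME:=default} with os.environ.get("NAME", "default").
    return _ENV_PATTERN.sub(lambda m: os.environ.get(m.group(1), m.group(2)), value)

print(resolve("${env.AWS_DEFAULT_REGION:=us-east-2}"))  # "us-east-2" unless the var is set
```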

llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:

llama_stack/distributions/starter/run.yaml

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:

llama_stack/providers/registry/inference.py

Lines changed: 3 additions & 2 deletions

@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,

llama_stack/providers/remote/inference/bedrock/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
 
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
 
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
 
     await impl.initialize()
 
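A minimal sketch of the lifecycle `get_adapter_impl` drives, assuming a placeholder API key (in practice the stack constructs the adapter from the run.yaml config, not by hand):

```python
# Sketch: construct and initialize the adapter directly, mirroring get_adapter_impl.
import asyncio

from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

async def main() -> None:
    config = BedrockConfig(api_key="example-key", region_name="us-east-2")  # placeholder key
    impl = BedrockInferenceAdapter(config=config)
    await impl.initialize()
    # Endpoint URL is derived from region_name by get_base_url()
    print(impl.get_base_url())  # https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1

asyncio.run(main())
```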
llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 69 additions & 102 deletions
@@ -4,139 +4,106 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 
-from botocore.client import BaseClient
+from openai import AuthenticationError
 
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
+    Model,
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
     OpenAIChatCompletionRequestWithExtraBody,
+    OpenAICompletion,
     OpenAICompletionRequestWithExtraBody,
     OpenAIEmbeddingsRequestWithExtraBody,
     OpenAIEmbeddingsResponse,
 )
-from llama_stack.apis.inference.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-)
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
-
-from .models import MODEL_ENTRIES
-
-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
-
-
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
-
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
-
-    # Fallback to US for anything we don't recognize
-    return "us."
-
-
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
-
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
-
-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"
-
-    return _get_region_prefix(region) + model_id
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
+from .config import BedrockConfig
 
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-        self._config = config
-        self._client = None
 
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client
+class BedrockInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
 
-    async def initialize(self) -> None:
-        pass
+    Supports Llama models across regions and GPT-OSS models (us-west-2 only).
 
-    async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
+    Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
+    for dynamic model discovery. Models must be pre-registered in the config.
+    """
 
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
+    config: BedrockConfig
+    provider_data_api_key_field: str = "aws_bedrock_api_key"
 
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        if not self.config.api_key:
+            raise ValueError(
+                "API key is not set. Please provide a valid API key in the "
+                "provider config or via AWS_BEDROCK_API_KEY environment variable."
+            )
+        return self.config.api_key
 
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
 
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
+        Returns empty list since models must be pre-registered in the config.
+        """
+        return []
 
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the Bedrock provider.
 
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
+        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
+        the availability check and accept all models registered in the config.
+        """
+        return model
 
     async def openai_embeddings(
         self,
         params: OpenAIEmbeddingsRequestWithExtraBody,
     ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
+        """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_completion(
        self,
         params: OpenAICompletionRequestWithExtraBody,
     ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+        """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
+        raise NotImplementedError(
+            "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
+            "Only /v1/chat/completions is supported. "
+            "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
+        )
 
     async def openai_chat_completion(
         self,
         params: OpenAIChatCompletionRequestWithExtraBody,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        """Override to enable streaming usage metrics and handle authentication errors."""
+        # Enable streaming usage metrics when telemetry is active
+        if params.stream and get_current_span() is not None:
+            if params.stream_options is None:
+                params.stream_options = {"include_usage": True}
+            elif "include_usage" not in params.stream_options:
+                params.stream_options = {**params.stream_options, "include_usage": True}
+
+        # Wrap call in try/except to catch authentication errors
+        try:
+            return await super().openai_chat_completion(params=params)
+        except AuthenticationError as e:
+            raise ValueError(
+                f"AWS Bedrock authentication failed: {e.message}. "
+                "Please check your API key in the provider config or x-llamastack-provider-data header."
+            ) from e
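Because the adapter sets `provider_data_api_key_field = "aws_bedrock_api_key"`, callers can also supply a per-request key through the `x-llamastack-provider-data` header instead of the static config. A hedged sketch (the server URL, port, and route are assumptions about a local llama-stack deployment, and the model ID is illustrative):

```python
# Sketch: per-request Bedrock API key via the x-llamastack-provider-data header.
import json

import requests

resp = requests.post(
    "http://localhost:8321/v1/chat/completions",  # assumed local llama-stack route
    headers={"x-llamastack-provider-data": json.dumps({"aws_bedrock_api_key": "per-request-key"})},
    json={
        "model": "meta.llama3-1-8b-instruct-v1:0",  # illustrative
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code, resp.json())
```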

llama_stack/providers/remote/inference/bedrock/config.py

Lines changed: 28 additions & 3 deletions
@@ -4,8 +4,33 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
+import os
 
+from pydantic import BaseModel, Field
 
-class BedrockConfig(BedrockBaseConfig):
-    pass
+from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
+
+
+class BedrockProviderDataValidator(BaseModel):
+    aws_bedrock_api_key: str | None = Field(
+        default=None,
+        description="API key for Amazon Bedrock",
+    )
+
+
+class BedrockConfig(RemoteInferenceProviderConfig):
+    api_key: str | None = Field(
+        default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
+        description="Amazon Bedrock API key",
+    )
+    region_name: str = Field(
+        default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
+        description="AWS Region for the Bedrock Runtime endpoint",
+    )
+
+    @classmethod
+    def sample_run_config(cls, **kwargs):
+        return {
+            "api_key": "${env.AWS_BEDROCK_API_KEY:=}",
+            "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
+        }
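A small sketch of how these defaults behave, using only what the diff defines: each field's `default_factory` falls back to an environment variable, and `sample_run_config` emits the env-substitution placeholders seen in the run.yaml files above.

```python
# Sketch: BedrockConfig falls back to environment variables via default_factory.
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ.pop("AWS_DEFAULT_REGION", None)  # ensure the default applies

config = BedrockConfig()
print(config.region_name)  # "us-east-2" (default_factory fallback)
print(config.api_key)      # None unless AWS_BEDROCK_API_KEY is set
print(BedrockConfig.sample_run_config())
# {'api_key': '${env.AWS_BEDROCK_API_KEY:=}', 'region_name': '${env.AWS_DEFAULT_REGION:=us-east-2}'}
```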

llama_stack/providers/remote/inference/bedrock/models.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
