
Commit 14919c1

feat: add OpenAI-compatible Bedrock provider
Implements an AWS Bedrock inference provider that uses Bedrock's OpenAI-compatible endpoint for Llama models available through Bedrock.

Changes:
- Add BedrockInferenceAdapter built on the OpenAIMixin base
- Configure region-specific endpoint URLs
- Support cross-region inference profiles with retry logic
- Add unit tests and integration tests
- Add provider registry configuration
1 parent: 96886af

File tree: 13 files changed, +532 / -191 lines
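For orientation before the per-file diffs: the provider added here is a thin wrapper around Bedrock's OpenAI-compatible runtime endpoint. The sketch below is not part of the commit; it shows the underlying call pattern, assuming the `openai` Python package, a Bedrock API key in `AWS_BEDROCK_API_KEY`, and an illustrative model ID. The base URL format matches `get_base_url()` in the `bedrock.py` diff further down.

```python
# Minimal sketch (not part of the commit): a standard OpenAI client pointed at
# Bedrock's OpenAI-compatible runtime endpoint.
import os

from openai import OpenAI

region = os.environ.get("AWS_DEFAULT_REGION", "us-east-2")
client = OpenAI(
    base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],  # Bedrock API key, not an OpenAI key
)

# "openai.gpt-oss-20b-1:0" is the example base model ID mentioned in the diff below;
# substitute any Bedrock model available in your region.
resp = client.chat.completions.create(
    model="openai.gpt-oss-20b-1:0",
    messages=[{"role": "user", "content": "Hello from Bedrock"}],
)
print(resp.choices[0].message.content)
```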
Lines changed: 6 additions & 13 deletions

@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,27 +8,20 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using OpenAI compatible endpoint.
 
 ## Configuration
 
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2. Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use. Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform. Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection. The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
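A small, hypothetical sketch of how the two remaining config fields line up with the sample YAML above; it assumes `BedrockConfig` is a standard pydantic config exposing exactly the documented fields, with the import path taken from the `config_class` entry in the registry diff below.

```python
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

# api_key may be omitted; the sample YAML falls back to ${env.AWS_BEDROCK_API_KEY:=}.
config = BedrockConfig(
    api_key="<bedrock-api-key>",
    region_name="us-east-2",  # region of the Bedrock Runtime endpoint
)

# The adapter derives its endpoint from region_name, i.e.
# https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1
```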

llama_stack/distributions/ci-tests/run.yaml

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
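The same three-line config block is added to the starter-gpu and starter distributions below. As an aside, the `${env.NAME:=default}` placeholders in these snippets resolve from the environment with a fallback default; the sketch below is illustrative only and is not the stack's actual substitution code.

```python
import os
import re

# Illustrative only: mimics how ${env.NAME:=default} placeholders in the run.yaml
# snippets above resolve; an empty default means the value may stay unset.
def resolve_env_placeholder(value: str) -> str:
    match = re.fullmatch(r"\$\{env\.([A-Z0-9_]+):=(.*)\}", value)
    if not match:
        return value
    name, default = match.groups()
    return os.environ.get(name, default)

print(resolve_env_placeholder("${env.AWS_DEFAULT_REGION:=us-east-2}"))  # "us-east-2" if unset
```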

llama_stack/distributions/starter-gpu/run.yaml

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
    config:

llama_stack/distributions/starter/run.yaml

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:

llama_stack/providers/registry/inference.py

Lines changed: 3 additions & 2 deletions
@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,
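The new `provider_data_validator` suggests the Bedrock API key can also be supplied per request through provider data instead of static config. A hedged sketch of that flow follows; the field name matches `provider_data_api_key_field = "aws_bedrock_api_key"` in the `bedrock.py` diff below, while the header name is the one Llama Stack conventionally uses for provider data and is an assumption here.

```python
import json

# Hypothetical per-request credential passing via the provider-data header.
provider_data = {"aws_bedrock_api_key": "<bedrock-api-key>"}
headers = {"X-LlamaStack-Provider-Data": json.dumps(provider_data)}
# Attach `headers` to requests sent to the Llama Stack server so the Bedrock
# adapter can resolve the key at request time instead of from run.yaml.
```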

llama_stack/providers/remote/inference/bedrock/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):
 
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
 
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
 
     await impl.initialize()
 
llama_stack/providers/remote/inference/bedrock/bedrock.py

Lines changed: 106 additions & 136 deletions
@@ -4,162 +4,69 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import json
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterable
 from typing import Any
 
-from botocore.client import BaseClient
+from openai import AuthenticationError, BadRequestError, NotFoundError
 
 from llama_stack.apis.inference import (
-    ChatCompletionRequest,
-    Inference,
-    OpenAIEmbeddingsResponse,
-)
-from llama_stack.apis.inference.inference import (
+    Model,
     OpenAIChatCompletion,
     OpenAIChatCompletionChunk,
-    OpenAICompletion,
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
-from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
-from llama_stack.providers.utils.bedrock.client import create_bedrock_client
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
-from llama_stack.providers.utils.inference.openai_compat import (
-    get_sampling_strategy_options,
-)
-from llama_stack.providers.utils.inference.prompt_adapter import (
-    chat_completion_request_to_prompt,
-)
-
-from .models import MODEL_ENTRIES
-
-REGION_PREFIX_MAP = {
-    "us": "us.",
-    "eu": "eu.",
-    "ap": "ap.",
-}
-
-
-def _get_region_prefix(region: str | None) -> str:
-    # AWS requires region prefixes for inference profiles
-    if region is None:
-        return "us."  # default to US when we don't know
-
-    # Handle case insensitive region matching
-    region_lower = region.lower()
-    for prefix in REGION_PREFIX_MAP:
-        if region_lower.startswith(f"{prefix}-"):
-            return REGION_PREFIX_MAP[prefix]
-
-    # Fallback to US for anything we don't recognize
-    return "us."
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
+from .config import BedrockConfig
 
-def _to_inference_profile_id(model_id: str, region: str = None) -> str:
-    # Return ARNs unchanged
-    if model_id.startswith("arn:"):
-        return model_id
-
-    # Return inference profile IDs that already have regional prefixes
-    if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
-        return model_id
-
-    # Default to US East when no region is provided
-    if region is None:
-        region = "us-east-1"
-
-    return _get_region_prefix(region) + model_id
+logger = get_logger(name=__name__, category="inference::bedrock")
 
 
-class BedrockInferenceAdapter(
-    ModelRegistryHelper,
-    Inference,
-):
-    def __init__(self, config: BedrockConfig) -> None:
-        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
-        self._config = config
-        self._client = None
+class BedrockInferenceAdapter(OpenAIMixin):
+    """
+    Adapter for AWS Bedrock's OpenAI-compatible API endpoints.
+
+    Supports Llama models across regions and GPT-OSS models (us-west-2 only).
+
+    Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
+    for dynamic model discovery. Models must be pre-registered in the config.
+    """
+
+    config: BedrockConfig
+    provider_data_api_key_field: str = "aws_bedrock_api_key"
 
-    @property
-    def client(self) -> BaseClient:
-        if self._client is None:
-            self._client = create_bedrock_client(self._config)
-        return self._client
+    def get_api_key(self) -> str:
+        """Get API key for OpenAI client."""
+        if not self.config.api_key:
+            raise ValueError(
+                "API key is not set. Please provide a valid API key in the "
+                "provider config or via AWS_BEDROCK_API_KEY environment variable."
+            )
+        return self.config.api_key
 
-    async def initialize(self) -> None:
-        pass
+    def get_base_url(self) -> str:
+        """Get base URL for OpenAI client."""
+        return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"
 
-    async def shutdown(self) -> None:
-        if self._client is not None:
-            self._client.close()
-
-    async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
-        bedrock_model = request.model
-
-        sampling_params = request.sampling_params
-        options = get_sampling_strategy_options(sampling_params)
-
-        if sampling_params.max_tokens:
-            options["max_gen_len"] = sampling_params.max_tokens
-        if sampling_params.repetition_penalty > 0:
-            options["repetition_penalty"] = sampling_params.repetition_penalty
-
-        prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
-
-        # Convert foundation model ID to inference profile ID
-        region_name = self.client.meta.region_name
-        inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
-
-        return {
-            "modelId": inference_profile_id,
-            "body": json.dumps(
-                {
-                    "prompt": prompt,
-                    **options,
-                }
-            ),
-        }
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
-    async def openai_completion(
-        self,
-        # Standard OpenAI completion parameters
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        # vLLM-specific parameters
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        # for fill-in-the-middle type completion
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
+    async def list_provider_model_ids(self) -> Iterable[str]:
+        """
+        Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
+        Returns empty list since models must be pre-registered in the config.
+        """
+        return []
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Register a model with the Bedrock provider.
+
+        Bedrock doesn't support dynamic model listing via /v1/models, so we skip
+        the availability check and accept all models registered in the config.
+        """
+        return model
 
     async def openai_chat_completion(
         self,
@@ -187,4 +94,67 @@ async def openai_chat_completion(
         top_p: float | None = None,
         user: str | None = None,
     ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
+        """Override to add Bedrock-specific model ID handling and streaming options."""
+        # Get the provider model ID from model store
+        model_obj = await self.model_store.get_model(model)  # type: ignore[attr-defined]
+        provider_model_id: str = model_obj.provider_resource_id or model
+
+        # Bedrock OpenAI-compatible endpoint expects base model IDs (e.g. "openai.gpt-oss-20b-1:0").
+        # Cross-region inference profile IDs prefixed with "us." are not recognized by the endpoint.
+        # Normalize to base model ID, then try both base and prefixed forms for compatibility.
+        base_model_id = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id
+        candidate_models = [base_model_id, f"us.{base_model_id}"]
+
+        # Enable streaming usage metrics when telemetry is active
+        if stream and get_current_span() is not None:
+            if stream_options is None:
+                stream_options = {"include_usage": True}
+            elif "include_usage" not in stream_options:
+                stream_options = {**stream_options, "include_usage": True}
+
+        # Try candidate model IDs with retry logic
+        last_error: Exception | None = None
+        for candidate in candidate_models:
+            try:
+                logger.debug(f"Attempting request with model ID: {candidate}")
+                # Call OpenAI client directly with the candidate model ID
+                # We can't use super().openai_chat_completion() because it would
+                # call self._get_provider_model_id() which looks up the model again
+                params = await prepare_openai_completion_params(
+                    model=candidate,
+                    messages=messages,
+                    frequency_penalty=frequency_penalty,
+                    function_call=function_call,
+                    functions=functions,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_completion_tokens=max_completion_tokens,
+                    max_tokens=max_tokens,
+                    n=n,
+                    parallel_tool_calls=parallel_tool_calls,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    seed=seed,
+                    stop=stop,
+                    stream=stream,
+                    stream_options=stream_options,
+                    temperature=temperature,
+                    tool_choice=tool_choice,
+                    tools=tools,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
+                    user=user,
+                )
+                resp = await self.client.chat.completions.create(**params)
+                return await self._maybe_overwrite_id(resp, stream)  # type: ignore[no-any-return]
+            except AuthenticationError as e:
+                # Authentication errors - no retry with different model IDs
+                raise ValueError(f"Authentication failed: {str(e)}") from e
+            except (NotFoundError, BadRequestError) as e:
+                logger.debug(f"Model ID {candidate} failed: {e}, trying next candidate")
+                last_error = e
+                continue
+
+        if last_error:
+            raise last_error
+        raise RuntimeError("Bedrock chat completion failed")
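To make the model-ID fallback easier to follow in isolation, here is a standalone distillation (not part of the diff) of the normalization step above; the example ID is illustrative.

```python
def candidate_model_ids(provider_model_id: str) -> list[str]:
    """Strip a cross-region "us." prefix if present, then try base and prefixed forms."""
    base = provider_model_id[3:] if provider_model_id.startswith("us.") else provider_model_id
    return [base, f"us.{base}"]


# Both a plain model ID and a "us."-prefixed inference profile ID normalize to the
# same candidate list, which the adapter tries in order until one succeeds.
assert candidate_model_ids("us.meta.llama3-1-8b-instruct-v1:0") == [
    "meta.llama3-1-8b-instruct-v1:0",
    "us.meta.llama3-1-8b-instruct-v1:0",
]
```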
