19 changes: 6 additions & 13 deletions docs/docs/providers/inference/remote_bedrock.mdx
@@ -1,5 +1,5 @@
---
description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
sidebar_label: Remote - Bedrock
title: remote::bedrock
---
@@ -8,27 +8,20 @@ title: remote::bedrock

## Description

AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
AWS Bedrock inference provider using OpenAI compatible endpoint.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
| `api_key` | `pydantic.types.SecretStr \| None` | No | | Authentication credential for the provider |
| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |

## Sample Configuration

```yaml
{}
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
```
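
The sample config above leans on Llama Stack's `${env.VAR:=default}` substitution, which fills in the environment value or falls back to the default. A minimal, hypothetical re-implementation of just the `:=` form (the stack's real resolver lives elsewhere and handles more cases):

```python
import os
import re

def resolve(template: str) -> str:
    """Substitute ${env.VAR:=default} with the env value, else the default."""
    def repl(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):=([^}]*)\}", repl, template)

print(resolve("api_key: ${env.AWS_BEDROCK_API_KEY:=}"))              # empty default
print(resolve("region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}"))  # us-east-2
```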
4 changes: 2 additions & 2 deletions src/llama_stack/core/routers/inference.py
@@ -190,7 +190,7 @@ async def openai_completion(

response = await provider.openai_completion(params)
response.model = request_model_id
if self.telemetry_enabled:
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
@@ -253,7 +253,7 @@ async def openai_chat_completion(
if self.store:
asyncio.create_task(self.store.store_chat_completion(response, params.messages))

if self.telemetry_enabled:
if self.telemetry_enabled and response.usage is not None:
metrics = self._construct_metrics(
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
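
The added `response.usage is not None` guard protects the metrics path: a provider may legitimately return a completion with `usage` unset (for example, a stream that never emitted a usage chunk), and `response.usage.prompt_tokens` would then raise `AttributeError`. A self-contained sketch of the guarded pattern, with illustrative types:

```python
from dataclasses import dataclass

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int

@dataclass
class Completion:
    usage: Usage | None  # providers may omit usage entirely

def metrics_for(response: Completion, telemetry_enabled: bool) -> list[dict]:
    # Without the None check, response.usage.prompt_tokens raises
    # AttributeError whenever a provider returns usage=None.
    if telemetry_enabled and response.usage is not None:
        return [{
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
        }]
    return []

print(metrics_for(Completion(usage=None), telemetry_enabled=True))  # []
```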
3 changes: 3 additions & 0 deletions src/llama_stack/distributions/ci-tests/run.yaml
@@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
3 changes: 3 additions & 0 deletions src/llama_stack/distributions/starter-gpu/run.yaml
@@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
3 changes: 3 additions & 0 deletions src/llama_stack/distributions/starter/run.yaml
@@ -46,6 +46,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
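
These three run.yaml hunks are identical; note the contrast with the neighboring `${env.NVIDIA_API_KEY:+nvidia}` line, whose `:+` form (assuming it mirrors shell parameter expansion, the way `:=` does) expands only when the variable is set, effectively disabling that provider otherwise. A hypothetical companion to the resolver sketched earlier:

```python
import os
import re

def resolve_conditional(template: str) -> str:
    """Expand ${env.VAR:+text} to text when VAR is set, else to ""."""
    def repl(match: re.Match) -> str:
        name, value_if_set = match.group(1), match.group(2)
        return value_if_set if os.environ.get(name) else ""
    return re.sub(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*):\+([^}]*)\}", repl, template)

print(resolve_conditional("${env.NVIDIA_API_KEY:+nvidia}"))  # "" when unset
```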
5 changes: 3 additions & 2 deletions src/llama_stack/providers/registry/inference.py
@@ -138,10 +138,11 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
pip_packages=[],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
),
RemoteProviderSpec(
api=Api.inference,
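
Registering `BedrockProviderDataValidator` is what allows a per-request key in place of the static `api_key` from run.yaml. Assuming the stack's usual provider-data flow (the error text in bedrock.py below also references this header), a caller would attach it like so; the JSON field name must match the adapter's `provider_data_api_key_field`:

```python
import json

# Illustrative client-side header; the server parses the JSON and the
# validator (added above) checks for the aws_bedrock_api_key field.
headers = {
    "x-llamastack-provider-data": json.dumps(
        {"aws_bedrock_api_key": "<fresh-bedrock-bearer-token>"}
    )
}
# Send `headers` with any chat-completion request to the stack server to
# override the statically configured credential for that request.
```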
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):

assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"

impl = BedrockInferenceAdapter(config)
impl = BedrockInferenceAdapter(config=config)

await impl.initialize()

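
The switch to `BedrockInferenceAdapter(config=config)` lines up with the adapter now subclassing `OpenAIMixin`: judging by the class-level `config: BedrockConfig` annotation in bedrock.py, the mixin appears to be a Pydantic model, and Pydantic models reject positional construction. A minimal sketch of that behavior:

```python
from pydantic import BaseModel

class Adapter(BaseModel):
    config: dict

Adapter(config={"region_name": "us-east-2"})  # ok
# Adapter({"region_name": "us-east-2"})       # TypeError: BaseModel does not
#                                             # accept positional arguments
```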
191 changes: 88 additions & 103 deletions src/llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -4,139 +4,124 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable

from botocore.client import BaseClient
from openai import AuthenticationError

from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)

from .models import MODEL_ENTRIES

REGION_PREFIX_MAP = {
"us": "us.",
"eu": "eu.",
"ap": "ap.",
}


def _get_region_prefix(region: str | None) -> str:
# AWS requires region prefixes for inference profiles
if region is None:
return "us." # default to US when we don't know

# Handle case insensitive region matching
region_lower = region.lower()
for prefix in REGION_PREFIX_MAP:
if region_lower.startswith(f"{prefix}-"):
return REGION_PREFIX_MAP[prefix]

# Fallback to US for anything we don't recognize
return "us."


def _to_inference_profile_id(model_id: str, region: str = None) -> str:
# Return ARNs unchanged
if model_id.startswith("arn:"):
return model_id

# Return inference profile IDs that already have regional prefixes
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
return model_id

# Default to US East when no region is provided
if region is None:
region = "us-east-1"

return _get_region_prefix(region) + model_id

from llama_stack.core.telemetry.tracing import get_current_span
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None
from .config import BedrockConfig

@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
logger = get_logger(name=__name__, category="inference::bedrock")

async def initialize(self) -> None:
pass

async def shutdown(self) -> None:
if self._client is not None:
self._client.close()
class BedrockInferenceAdapter(OpenAIMixin):
"""
Adapter for AWS Bedrock's OpenAI-compatible API endpoints.

async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
Supports Llama models across regions and GPT-OSS models (us-west-2 only).

sampling_params = request.sampling_params
options = get_sampling_strategy_options(sampling_params)
Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
for dynamic model discovery. Models must be pre-registered in the config.
"""

if sampling_params.max_tokens:
options["max_gen_len"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
config: BedrockConfig
provider_data_api_key_field: str = "aws_bedrock_api_key"

prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"

# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
async def list_provider_model_ids(self) -> Iterable[str]:
"""
Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
Returns empty list since models must be pre-registered in the config.
"""
return []

return {
"modelId": inference_profile_id,
"body": json.dumps(
{
"prompt": prompt,
**options,
}
),
}
async def check_model_availability(self, model: str) -> bool:
"""
Bedrock doesn't support dynamic model listing via /v1/models.
Always return True to accept all models registered in the config.
"""
return True

async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
"""Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
)

async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
"Only /v1/chat/completions is supported. "
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
)

async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
"""Override to enable streaming usage metrics and handle authentication errors."""
# Enable streaming usage metrics when telemetry is active
if params.stream and get_current_span() is not None:
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

try:
logger.debug(f"Calling Bedrock OpenAI API with model={params.model}, stream={params.stream}")
result = await super().openai_chat_completion(params=params)
logger.debug(f"Bedrock API returned: {type(result).__name__ if result is not None else 'None'}")

if result is None:
logger.error(f"Bedrock OpenAI client returned None for model={params.model}, stream={params.stream}")
raise RuntimeError(
f"Bedrock API returned no response for model '{params.model}'. "
"This may indicate the model is not supported or a network/API issue occurred."
)

return result
except AuthenticationError as e:
error_msg = str(e)

# Check if this is a token expiration error
if "expired" in error_msg.lower() or "Bearer Token has expired" in error_msg:
logger.error(f"AWS Bedrock authentication token expired: {error_msg}")
raise ValueError(
"AWS Bedrock authentication failed: Bearer token has expired. "
"The AWS_BEDROCK_API_KEY environment variable contains an expired pre-signed URL. "
"Please refresh your token by generating a new pre-signed URL with AWS credentials. "
"Refer to AWS Bedrock documentation for details on OpenAI-compatible endpoints."
) from e
else:
logger.error(f"AWS Bedrock authentication failed: {error_msg}")
raise ValueError(
f"AWS Bedrock authentication failed: {error_msg}. "
"Please verify your API key is correct in the provider config or x-llamastack-provider-data header. "
"The API key should be a valid AWS pre-signed URL for Bedrock's OpenAI-compatible endpoint."
) from e
except Exception as e:
logger.error(f"Unexpected error calling Bedrock API: {type(e).__name__}: {e}", exc_info=True)
raise
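
Under the hood, the rewritten adapter is now a thin layer over a stock OpenAI client pointed at Bedrock's OpenAI-compatible runtime. A minimal sketch of the equivalent direct call, with an illustrative region, token, and model id:

```python
import asyncio

from openai import AsyncOpenAI

client = AsyncOpenAI(
    # Mirrors get_base_url() above for region_name="us-east-2".
    base_url="https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1",
    api_key="<bedrock-bearer-token>",  # e.g. the value of AWS_BEDROCK_API_KEY
)

async def demo() -> None:
    # Only /v1/chat/completions is supported; the adapter raises
    # NotImplementedError for /v1/completions and /v1/embeddings.
    response = await client.chat.completions.create(
        model="us.meta.llama3-3-70b-instruct-v1:0",  # illustrative model id
        messages=[{"role": "user", "content": "Hello from Bedrock"}],
        stream=False,
    )
    print(response.choices[0].message.content)

# asyncio.run(demo())  # requires valid AWS Bedrock credentials
```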
27 changes: 24 additions & 3 deletions src/llama_stack/providers/remote/inference/bedrock/config.py
@@ -4,8 +4,29 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
import os

from pydantic import BaseModel, Field

class BedrockConfig(BedrockBaseConfig):
pass
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


class BedrockProviderDataValidator(BaseModel):
aws_bedrock_api_key: str | None = Field(
default=None,
description="API key for Amazon Bedrock",
)


class BedrockConfig(RemoteInferenceProviderConfig):
region_name: str = Field(
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
description="AWS Region for the Bedrock Runtime endpoint",
)

@classmethod
def sample_run_config(cls, **kwargs):
return {
"api_key": "${env.AWS_BEDROCK_API_KEY:=}",
"region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
}
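
A quick check of how the new defaults resolve at construction time, assuming the module is importable and that `RemoteInferenceProviderConfig` leaves `api_key` optional; `default_factory` runs per instantiation, so the environment is read each time:

```python
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ.pop("AWS_DEFAULT_REGION", None)  # simulate an unset variable
print(BedrockConfig().region_name)          # "us-east-2" via default_factory

os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
print(BedrockConfig().region_name)          # "eu-west-1"
```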