diff --git a/examples/bedrock_quickstart.py b/examples/bedrock_quickstart.py
new file mode 100644
index 00000000..81af5fb9
--- /dev/null
+++ b/examples/bedrock_quickstart.py
@@ -0,0 +1,50 @@
+"""Quickstart example for using RLM with AWS Bedrock.
+
+This example demonstrates how to use RLM with AWS Bedrock via Project Mantle,
+Amazon's OpenAI-compatible endpoint for Bedrock models.
+
+Prerequisites:
+    1. Generate a Bedrock API key in AWS Console → Bedrock → API keys
+    2. Set environment variables:
+       export AWS_BEDROCK_API_KEY=your-bedrock-api-key
+       export AWS_BEDROCK_REGION=us-east-1  # optional, defaults to us-east-1
+
+Usage:
+    python examples/bedrock_quickstart.py
+"""
+
+import os
+
+from dotenv import load_dotenv
+
+from rlm import RLM
+from rlm.logger import RLMLogger
+
+load_dotenv()
+
+# Initialize logger for debugging
+logger = RLMLogger(log_dir="./logs")
+
+# Create RLM instance with Bedrock backend
+rlm = RLM(
+    backend="bedrock",
+    backend_kwargs={
+        "model_name": "qwen.qwen3-coder-30b-a3b-v1:0",  # Qwen3 Coder for code tasks
+        # api_key and region are read from environment variables:
+        # - AWS_BEDROCK_API_KEY
+        # - AWS_BEDROCK_REGION (defaults to us-east-1)
+    },
+    environment="docker",  # or "local" for direct execution
+    environment_kwargs={},
+    max_depth=1,
+    logger=logger,
+    verbose=True,
+)
+
+# Run a simple RLM completion
+result = rlm.completion("Print the first 5 powers of two, each on its own line.")
+
+print("\n" + "=" * 50)
+print("RLM Result:")
+print("=" * 50)
+print(result)
diff --git a/rlm/clients/__init__.py b/rlm/clients/__init__.py
index 09d3008b..f20b0551 100644
--- a/rlm/clients/__init__.py
+++ b/rlm/clients/__init__.py
@@ -57,7 +57,11 @@ def get_client(
         from rlm.clients.azure_openai import AzureOpenAIClient
 
         return AzureOpenAIClient(**backend_kwargs)
+    elif backend == "bedrock":
+        from rlm.clients.bedrock import BedrockClient
+
+        return BedrockClient(**backend_kwargs)
     else:
         raise ValueError(
-            f"Unknown backend: {backend}. Supported backends: ['openai', 'vllm', 'portkey', 'openrouter', 'litellm', 'anthropic', 'azure_openai', 'gemini', 'vercel']"
+            f"Unknown backend: {backend}. Supported backends: ['openai', 'vllm', 'portkey', 'openrouter', 'litellm', 'anthropic', 'azure_openai', 'gemini', 'vercel', 'bedrock']"
         )
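Note for reviewers: with the new branch in place, the factory in rlm/clients/__init__.py can be exercised directly. A minimal sketch, assuming get_client takes the backend name and backend_kwargs as the hunk above suggests (the full signature is not visible in this diff):

    # Hypothetical direct use of the factory; RLM(backend="bedrock") routes here.
    from rlm.clients import get_client

    client = get_client(
        backend="bedrock",
        backend_kwargs={"model_name": "qwen.qwen3-coder-30b-a3b-v1:0"},
    )
    print(type(client).__name__)  # -> BedrockClient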
diff --git a/rlm/clients/bedrock.py b/rlm/clients/bedrock.py
new file mode 100644
index 00000000..d0139a31
--- /dev/null
+++ b/rlm/clients/bedrock.py
@@ -0,0 +1,243 @@
+"""AWS Bedrock client for RLM via Project Mantle (OpenAI-compatible endpoint).
+
+Amazon Bedrock provides an OpenAI-compatible API through Project Mantle,
+allowing seamless integration with existing OpenAI SDK code.
+
+Environment variables:
+    AWS_BEDROCK_API_KEY: Bedrock API key (Bearer token).
+        Generate in AWS Console → Bedrock → API keys.
+    AWS_BEDROCK_REGION: AWS region for Mantle endpoint (default: us-east-1).
+
+Endpoint format:
+    https://bedrock-mantle.{region}.api.aws/v1
+
+Supported models (examples):
+    - qwen.qwen3-32b-v1:0
+    - qwen.qwen3-coder-30b-a3b-v1:0
+    - qwen.qwen3-235b-a22b-2507-v1:0
+    - amazon.nova-micro-v1:0
+    - meta.llama3-2-1b-instruct-v1:0
+
+Usage:
+    from rlm import RLM
+
+    rlm = RLM(
+        backend="bedrock",
+        backend_kwargs={
+            "model_name": "qwen.qwen3-coder-30b-a3b-v1:0",
+            # api_key and region are optional if env vars are set
+        },
+    )
+    result = rlm.completion("Your prompt here")
+
+See also:
+    https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-mantle.html
+    https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys.html
+"""
+
+import os
+from collections import defaultdict
+from typing import Any
+
+import openai
+from dotenv import load_dotenv
+
+from rlm.clients.base_lm import BaseLM
+from rlm.core.types import ModelUsageSummary, UsageSummary
+
+load_dotenv()
+
+# Default environment variables
+DEFAULT_BEDROCK_API_KEY = os.getenv("AWS_BEDROCK_API_KEY")
+DEFAULT_BEDROCK_REGION = os.getenv("AWS_BEDROCK_REGION", "us-east-1")
+
+
+def _build_mantle_base_url(region: str) -> str:
+    """Build the Bedrock Mantle endpoint URL for a given region."""
+    return f"https://bedrock-mantle.{region}.api.aws/v1"
+
+
+class BedrockClient(BaseLM):
+    """LM Client for AWS Bedrock via Project Mantle (OpenAI-compatible API).
+
+    Bedrock's Project Mantle provides a native OpenAI-compatible endpoint
+    that accepts Bedrock API keys as Bearer tokens. This client uses the
+    standard OpenAI SDK under the hood, configured to point at Mantle.
+
+    Args:
+        api_key: Bedrock API key. Falls back to AWS_BEDROCK_API_KEY env var.
+        model_name: Model ID (e.g., "qwen.qwen3-coder-30b-a3b-v1:0").
+        region: AWS region for Mantle endpoint. Falls back to AWS_BEDROCK_REGION
+            env var, then defaults to "us-east-1".
+        base_url: Override the Mantle endpoint URL. If not provided, it's
+            constructed from the region.
+        max_tokens: Maximum tokens for completion (default: 32768).
+        **kwargs: Additional arguments passed to BaseLM.
+
+    Raises:
+        ValueError: If no API key is provided and AWS_BEDROCK_API_KEY is not set.
+    """
+
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        region: str | None = None,
+        base_url: str | None = None,
+        max_tokens: int = 32768,
+        **kwargs,
+    ):
+        super().__init__(model_name=model_name, **kwargs)
+
+        # Resolve API key
+        if api_key is None:
+            api_key = DEFAULT_BEDROCK_API_KEY
+        if api_key is None:
+            raise ValueError(
+                "Bedrock API key is required. "
+                "Set AWS_BEDROCK_API_KEY environment variable or pass api_key parameter."
+            )
+
+        # Resolve region and base URL
+        region = region or DEFAULT_BEDROCK_REGION
+        if base_url is None:
+            base_url = _build_mantle_base_url(region)
+
+        # Initialize OpenAI clients pointing to Mantle
+        self.client = openai.OpenAI(api_key=api_key, base_url=base_url)
+        self.async_client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url)
+        self.model_name = model_name
+        self.max_tokens = max_tokens
+        self.region = region
+        self.base_url = base_url
+
+        # Per-model usage tracking
+        self.model_call_counts: dict[str, int] = defaultdict(int)
+        self.model_input_tokens: dict[str, int] = defaultdict(int)
+        self.model_output_tokens: dict[str, int] = defaultdict(int)
+        self.model_total_tokens: dict[str, int] = defaultdict(int)
+        self.last_prompt_tokens: int = 0
+        self.last_completion_tokens: int = 0
+
+    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+        """Generate a completion using Bedrock via Mantle.
+
+        Args:
+            prompt: Either a string or a list of message dicts with "role" and "content".
+            model: Override the model for this request.
+
+        Returns:
+            The generated text response.
+
+        Raises:
+            ValueError: If no model is specified.
+        """
+        messages = self._prepare_messages(prompt)
+
+        model = model or self.model_name
+        if not model:
+            raise ValueError("Model name is required for Bedrock client.")
+
+        response = self.client.chat.completions.create(
+            model=model,
+            messages=messages,
+            max_tokens=self.max_tokens,
+        )
+        self._track_cost(response, model)
+        return response.choices[0].message.content
+
+    async def acompletion(
+        self, prompt: str | list[dict[str, Any]], model: str | None = None
+    ) -> str:
+        """Generate a completion asynchronously using Bedrock via Mantle.
+
+        Args:
+            prompt: Either a string or a list of message dicts with "role" and "content".
+            model: Override the model for this request.
+
+        Returns:
+            The generated text response.
+
+        Raises:
+            ValueError: If no model is specified.
+        """
+        messages = self._prepare_messages(prompt)
+
+        model = model or self.model_name
+        if not model:
+            raise ValueError("Model name is required for Bedrock client.")
+
+        response = await self.async_client.chat.completions.create(
+            model=model,
+            messages=messages,
+            max_tokens=self.max_tokens,
+        )
+        self._track_cost(response, model)
+        return response.choices[0].message.content
+
+    def _prepare_messages(self, prompt: str | list[dict[str, Any]]) -> list[dict[str, Any]]:
+        """Convert prompt to OpenAI message format.
+
+        Args:
+            prompt: Either a string or a list of message dicts.
+
+        Returns:
+            List of message dicts in OpenAI format.
+
+        Raises:
+            ValueError: If prompt type is not supported.
+        """
+        if isinstance(prompt, str):
+            return [{"role": "user", "content": prompt}]
+        elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
+            return prompt
+        else:
+            raise ValueError(f"Invalid prompt type: {type(prompt)}")
+
+    def _track_cost(self, response: openai.types.chat.ChatCompletion, model: str) -> None:
+        """Track token usage for cost monitoring.
+
+        Args:
+            response: The completion response from the OpenAI SDK.
+            model: The model name used for this request.
+        """
+        self.model_call_counts[model] += 1
+
+        usage = getattr(response, "usage", None)
+        if usage is not None:
+            self.model_input_tokens[model] += usage.prompt_tokens
+            self.model_output_tokens[model] += usage.completion_tokens
+            self.model_total_tokens[model] += usage.total_tokens
+            self.last_prompt_tokens = usage.prompt_tokens
+            self.last_completion_tokens = usage.completion_tokens
+        else:
+            # Some responses may not include usage; track call count only
+            self.last_prompt_tokens = 0
+            self.last_completion_tokens = 0
+
+    def get_usage_summary(self) -> UsageSummary:
+        """Get aggregated usage summary for all models.
+
+        Returns:
+            UsageSummary with per-model token counts and call counts.
+        """
+        model_summaries = {}
+        for model in self.model_call_counts:
+            model_summaries[model] = ModelUsageSummary(
+                total_calls=self.model_call_counts[model],
+                total_input_tokens=self.model_input_tokens[model],
+                total_output_tokens=self.model_output_tokens[model],
+            )
+        return UsageSummary(model_usage_summaries=model_summaries)
+
+    def get_last_usage(self) -> ModelUsageSummary:
+        """Get usage summary for the last API call.
+
+        Returns:
+            ModelUsageSummary for the most recent completion.
+        """
+        return ModelUsageSummary(
+            total_calls=1,
+            total_input_tokens=self.last_prompt_tokens,
+            total_output_tokens=self.last_completion_tokens,
+        )
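The client above is also usable standalone, outside the RLM wrapper. A minimal sketch using only names defined in rlm/clients/bedrock.py (assumes AWS_BEDROCK_API_KEY is exported; the model ID is one of the docstring examples):

    from rlm.clients.bedrock import BedrockClient

    client = BedrockClient(model_name="qwen.qwen3-coder-30b-a3b-v1:0")
    text = client.completion("Say hello in one word.")
    print(text)

    # Per-model accounting accumulates across calls.
    summary = client.get_usage_summary()
    for name, usage in summary.model_usage_summaries.items():
        print(name, usage.total_calls, usage.total_input_tokens, usage.total_output_tokens)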
diff --git a/rlm/core/types.py b/rlm/core/types.py
index f20474d4..eb375443 100644
--- a/rlm/core/types.py
+++ b/rlm/core/types.py
@@ -12,6 +12,7 @@
     "anthropic",
     "azure_openai",
     "gemini",
+    "bedrock",
 ]
 
 EnvironmentType = Literal["local", "docker", "modal", "prime", "daytona"]
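Adding "bedrock" to the backend literal is what lets type checkers accept RLM(backend="bedrock"). Since BedrockClient also ships an async path, here is a minimal sketch of acompletion (assumes AWS_BEDROCK_API_KEY is set; model ID from the docstring examples):

    import asyncio

    from rlm.clients.bedrock import BedrockClient

    async def main() -> None:
        client = BedrockClient(model_name="qwen.qwen3-coder-30b-a3b-v1:0")
        text = await client.acompletion("What is 2+2? Reply with just the number.")
        print(text)

    asyncio.run(main())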
+ """ + self.model_call_counts[model] += 1 + + usage = getattr(response, "usage", None) + if usage is not None: + self.model_input_tokens[model] += usage.prompt_tokens + self.model_output_tokens[model] += usage.completion_tokens + self.model_total_tokens[model] += usage.total_tokens + self.last_prompt_tokens = usage.prompt_tokens + self.last_completion_tokens = usage.completion_tokens + else: + # Some responses may not include usage; track call count only + self.last_prompt_tokens = 0 + self.last_completion_tokens = 0 + + def get_usage_summary(self) -> UsageSummary: + """Get aggregated usage summary for all models. + + Returns: + UsageSummary with per-model token counts and call counts. + """ + model_summaries = {} + for model in self.model_call_counts: + model_summaries[model] = ModelUsageSummary( + total_calls=self.model_call_counts[model], + total_input_tokens=self.model_input_tokens[model], + total_output_tokens=self.model_output_tokens[model], + ) + return UsageSummary(model_usage_summaries=model_summaries) + + def get_last_usage(self) -> ModelUsageSummary: + """Get usage summary for the last API call. + + Returns: + ModelUsageSummary for the most recent completion. + """ + return ModelUsageSummary( + total_calls=1, + total_input_tokens=self.last_prompt_tokens, + total_output_tokens=self.last_completion_tokens, + ) diff --git a/rlm/core/types.py b/rlm/core/types.py index f20474d4..eb375443 100644 --- a/rlm/core/types.py +++ b/rlm/core/types.py @@ -12,6 +12,7 @@ "anthropic", "azure_openai", "gemini", + "bedrock", ] EnvironmentType = Literal["local", "docker", "modal", "prime", "daytona"] diff --git a/tests/clients/test_bedrock.py b/tests/clients/test_bedrock.py new file mode 100644 index 00000000..77ebd1c3 --- /dev/null +++ b/tests/clients/test_bedrock.py @@ -0,0 +1,284 @@ +"""Tests for the AWS Bedrock client.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +from dotenv import load_dotenv + +from rlm.clients.bedrock import BedrockClient, _build_mantle_base_url +from rlm.core.types import ModelUsageSummary, UsageSummary + +load_dotenv() + +# Default test model - Qwen3 Coder for code generation tasks +TEST_MODEL = "qwen.qwen3-coder-30b-a3b-v1:0" + + +class TestBedrockClientUnit: + """Unit tests that don't require API calls.""" + + def test_build_mantle_base_url(self): + """Test Mantle URL construction for different regions.""" + assert _build_mantle_base_url("us-east-1") == "https://bedrock-mantle.us-east-1.api.aws/v1" + assert _build_mantle_base_url("eu-west-1") == "https://bedrock-mantle.eu-west-1.api.aws/v1" + assert _build_mantle_base_url("ap-northeast-1") == "https://bedrock-mantle.ap-northeast-1.api.aws/v1" + + def test_init_with_explicit_params(self): + """Test client initialization with explicit parameters.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient( + api_key="test-key", + model_name=TEST_MODEL, + region="us-west-2", + ) + assert client.model_name == TEST_MODEL + assert client.region == "us-west-2" + assert client.base_url == "https://bedrock-mantle.us-west-2.api.aws/v1" + + def test_init_with_custom_base_url(self): + """Test client initialization with custom base URL override.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + custom_url = "https://custom-endpoint.example.com/v1" + client = BedrockClient( + api_key="test-key", + model_name=TEST_MODEL, + base_url=custom_url, + ) + assert 
client.base_url == custom_url + + def test_init_default_region(self): + """Test client uses default region (us-east-1).""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + with patch.dict(os.environ, {}, clear=False): + # Clear region env var if set + os.environ.pop("AWS_BEDROCK_REGION", None) + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + assert client.region == "us-east-1" + + def test_init_requires_api_key(self): + """Test client raises error when no API key provided.""" + with patch.dict(os.environ, {}, clear=True): + with patch("rlm.clients.bedrock.DEFAULT_BEDROCK_API_KEY", None): + with pytest.raises(ValueError, match="Bedrock API key is required"): + BedrockClient(api_key=None) + + def test_init_uses_env_api_key(self): + """Test client uses API key from environment variable.""" + with patch("rlm.clients.bedrock.openai.OpenAI") as mock_openai: + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + with patch("rlm.clients.bedrock.DEFAULT_BEDROCK_API_KEY", "env-api-key"): + client = BedrockClient(model_name=TEST_MODEL) + # Verify OpenAI client was called with env key + mock_openai.assert_called_once() + call_kwargs = mock_openai.call_args[1] + assert call_kwargs["api_key"] == "env-api-key" + + def test_usage_tracking_initialization(self): + """Test that usage tracking is properly initialized.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + assert client.model_call_counts == {} + assert client.model_input_tokens == {} + assert client.model_output_tokens == {} + assert client.last_prompt_tokens == 0 + assert client.last_completion_tokens == 0 + + def test_get_usage_summary_empty(self): + """Test usage summary when no calls have been made.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + summary = client.get_usage_summary() + assert isinstance(summary, UsageSummary) + assert summary.model_usage_summaries == {} + + def test_get_last_usage(self): + """Test last usage returns correct format.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + client.last_prompt_tokens = 100 + client.last_completion_tokens = 50 + usage = client.get_last_usage() + assert isinstance(usage, ModelUsageSummary) + assert usage.total_calls == 1 + assert usage.total_input_tokens == 100 + assert usage.total_output_tokens == 50 + + def test_prepare_messages_string(self): + """Test _prepare_messages with string input.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + messages = client._prepare_messages("Hello world") + assert messages == [{"role": "user", "content": "Hello world"}] + + def test_prepare_messages_list(self): + """Test _prepare_messages with message list input.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + input_messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + messages = client._prepare_messages(input_messages) + 
assert messages == input_messages + + def test_prepare_messages_invalid_type(self): + """Test _prepare_messages raises on invalid input.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + with pytest.raises(ValueError, match="Invalid prompt type"): + client._prepare_messages(12345) + + def test_completion_requires_model(self): + """Test completion raises when no model specified.""" + with patch("rlm.clients.bedrock.openai.OpenAI"): + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + client = BedrockClient(api_key="test-key", model_name=None) + with pytest.raises(ValueError, match="Model name is required"): + client.completion("Hello") + + def test_completion_with_mocked_response(self): + """Test completion with mocked API response.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Hello from Bedrock!" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.bedrock.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_client + + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + result = client.completion("Hello") + + assert result == "Hello from Bedrock!" + assert client.model_call_counts[TEST_MODEL] == 1 + assert client.model_input_tokens[TEST_MODEL] == 10 + assert client.model_output_tokens[TEST_MODEL] == 5 + + def test_completion_tracks_usage_across_calls(self): + """Test that usage is accumulated across multiple calls.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.bedrock.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.bedrock.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_openai_class.return_value = mock_client + + client = BedrockClient(api_key="test-key", model_name=TEST_MODEL) + client.completion("Hello 1") + client.completion("Hello 2") + client.completion("Hello 3") + + assert client.model_call_counts[TEST_MODEL] == 3 + assert client.model_input_tokens[TEST_MODEL] == 30 + assert client.model_output_tokens[TEST_MODEL] == 15 + + summary = client.get_usage_summary() + assert summary.model_usage_summaries[TEST_MODEL].total_calls == 3 + + +class TestBedrockClientIntegration: + """Integration tests that require a real Bedrock API key. + + These tests are skipped if AWS_BEDROCK_API_KEY is not set. + Set the environment variable to run integration tests: + + export AWS_BEDROCK_API_KEY=your-bedrock-api-key + pytest tests/clients/test_bedrock.py -v + """ + + @pytest.mark.skipif( + not os.environ.get("AWS_BEDROCK_API_KEY"), + reason="AWS_BEDROCK_API_KEY not set", + ) + def test_simple_completion(self): + """Test a simple completion with real API.""" + client = BedrockClient(model_name=TEST_MODEL) + result = client.completion("What is 2+2? 
Reply with just the number.") + assert "4" in result + + # Verify usage was tracked + usage = client.get_usage_summary() + assert TEST_MODEL in usage.model_usage_summaries + assert usage.model_usage_summaries[TEST_MODEL].total_calls == 1 + + @pytest.mark.skipif( + not os.environ.get("AWS_BEDROCK_API_KEY"), + reason="AWS_BEDROCK_API_KEY not set", + ) + def test_message_list_completion(self): + """Test completion with message list format.""" + client = BedrockClient(model_name=TEST_MODEL) + messages = [ + {"role": "system", "content": "You are a helpful math tutor."}, + {"role": "user", "content": "What is 5 * 5? Reply with just the number."}, + ] + result = client.completion(messages) + assert "25" in result + + @pytest.mark.skipif( + not os.environ.get("AWS_BEDROCK_API_KEY"), + reason="AWS_BEDROCK_API_KEY not set", + ) + @pytest.mark.asyncio + async def test_async_completion(self): + """Test async completion.""" + client = BedrockClient(model_name=TEST_MODEL) + result = await client.acompletion("What is 3+3? Reply with just the number.") + assert "6" in result + + @pytest.mark.skipif( + not os.environ.get("AWS_BEDROCK_API_KEY"), + reason="AWS_BEDROCK_API_KEY not set", + ) + def test_code_generation(self): + """Test code generation capability (important for RLM use case).""" + client = BedrockClient(model_name=TEST_MODEL) + result = client.completion( + "Write a Python function that returns the sum of a list of numbers. " + "Reply with only the function code, no explanation." + ) + assert "def" in result + assert "sum" in result.lower() or "return" in result + + +if __name__ == "__main__": + # Run integration tests directly + print("Running Bedrock integration tests...") + print(f"Using model: {TEST_MODEL}") + + test = TestBedrockClientIntegration() + + print("\n1. Testing simple completion...") + test.test_simple_completion() + print(" ✓ Passed") + + print("\n2. Testing message list completion...") + test.test_message_list_completion() + print(" ✓ Passed") + + print("\n3. Testing code generation...") + test.test_code_generation() + print(" ✓ Passed") + + print("\n✓ All integration tests passed!")
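The mocked-response pattern from the unit tests also works interactively for inspecting the accounting logic without network access. A small sketch (stand-in values, no real API traffic) showing that counters key off the model actually used, including the per-call override accepted by completion(model=...):

    from unittest.mock import MagicMock, patch

    from rlm.clients.bedrock import BedrockClient

    fake = MagicMock()
    fake.choices = [MagicMock()]
    fake.choices[0].message.content = "ok"
    fake.usage.prompt_tokens = 7
    fake.usage.completion_tokens = 3
    fake.usage.total_tokens = 10

    with patch("rlm.clients.bedrock.openai.OpenAI") as cls, patch(
        "rlm.clients.bedrock.openai.AsyncOpenAI"
    ):
        cls.return_value.chat.completions.create.return_value = fake
        client = BedrockClient(api_key="test-key", model_name="qwen.qwen3-32b-v1:0")
        client.completion("hi")
        client.completion("hi", model="amazon.nova-micro-v1:0")  # per-call override
        print(dict(client.model_call_counts))
        # {'qwen.qwen3-32b-v1:0': 1, 'amazon.nova-micro-v1:0': 1}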