diff --git a/README.md b/README.md index 57259675c..f13da2eb7 100644 --- a/README.md +++ b/README.md @@ -221,6 +221,9 @@ pip install graphiti-core[falkordb,anthropic,google-genai] # Install with Amazon Neptune pip install graphiti-core[neptune] + +# Install with Amazon Bedrock +pip install graphiti-core[bedrock] ``` ## Default to Low Concurrency; LLM Provider 429 Rate Limit Errors @@ -544,6 +547,68 @@ graphiti = Graphiti( Ensure Ollama is running (`ollama serve`) and that you have pulled the models you want to use. +## Using Graphiti with Amazon Bedrock + +Graphiti supports Amazon Bedrock for LLM inference, embeddings, and reranking. Amazon Bedrock provides access to foundation models from leading AI companies. + +Install Graphiti with Amazon Bedrock support: + +```bash +uv add "graphiti-core[bedrock]" + +# or + +pip install "graphiti-core[bedrock]" +``` + +```python +from graphiti_core import Graphiti +from graphiti_core.llm_client.amazon_bedrock_client import AmazonBedrockLLMClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.embedder.amazon_bedrock import AmazonBedrockEmbedder, AmazonBedrockEmbedderConfig +from graphiti_core.cross_encoder.amazon_bedrock_reranker_client import AmazonBedrockRerankerClient + +# Configure Amazon Bedrock clients +llm_client = AmazonBedrockLLMClient( + config=LLMConfig( + model="us.anthropic.claude-sonnet-4-20250514-v1:0", + temperature=0.1, + max_tokens=1000, + ), + region="us-east-1", +) + +embedder_client = AmazonBedrockEmbedder( + config=AmazonBedrockEmbedderConfig( + model="amazon.titan-embed-text-v2:0", + region="us-east-1", + ) +) + +reranker_client = AmazonBedrockRerankerClient( + model="cohere.rerank-v3-5:0", + region="us-east-1", + max_results=10, +) + +# Initialize Graphiti with Amazon Bedrock clients +graphiti = Graphiti( + "bolt://localhost:7687", + "neo4j", + "password", + llm_client=llm_client, + embedder=embedder_client, + cross_encoder=reranker_client, +) + +# Now you can use Graphiti with Amazon Bedrock +``` + +**Key Points:** +- Requires AWS credentials configured (via AWS CLI, environment variables, or IAM roles) +- Different models are available in different AWS regions - check the [AWS Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html) for availability +- See `examples/amazon-bedrock/` for a complete working example + ## Documentation - [Guides and API documentation](https://help.getzep.com/graphiti). diff --git a/examples/amazon-bedrock/.env.example b/examples/amazon-bedrock/.env.example new file mode 100644 index 000000000..93b5d33aa --- /dev/null +++ b/examples/amazon-bedrock/.env.example @@ -0,0 +1,20 @@ +# Neo4j Configuration +NEO4J_URI=bolt://localhost:7687 +NEO4J_USER=neo4j +NEO4J_PASSWORD=password + +# AWS Configuration +# AWS credentials can be configured via: +# 1. Environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) +# 2. AWS credentials file (~/.aws/credentials) +# 3. IAM roles (for EC2 instances) +AWS_REGION=us-east-1 + +# Amazon Bedrock Model Configuration +BEDROCK_LLM_MODEL=us.anthropic.claude-sonnet-4-20250514-v1:0 +BEDROCK_EMBEDDING_MODEL=amazon.titan-embed-text-v2:0 +BEDROCK_RERANKER_MODEL=cohere.rerank-v3-5:0 + +# Optional: Explicit AWS credentials (not recommended for production) +# AWS_ACCESS_KEY_ID=your_access_key_here +# AWS_SECRET_ACCESS_KEY=your_secret_key_here \ No newline at end of file diff --git a/examples/amazon-bedrock/README.md b/examples/amazon-bedrock/README.md new file mode 100644 index 000000000..d192901b6 --- /dev/null +++ b/examples/amazon-bedrock/README.md @@ -0,0 +1,179 @@ +# Amazon Bedrock + Neo4j Example + +This example demonstrates how to use Graphiti with Amazon Bedrock for LLM inference, embeddings, and reranking, combined with Neo4j as the graph database. + +## Features Demonstrated + +- **Amazon Bedrock LLM Client**: Uses Claude models for text generation and structured output +- **Amazon Bedrock Embedder**: Uses Titan embedding models for semantic search +- **Amazon Bedrock Reranker**: Uses Cohere or Amazon rerank models for result reranking +- **Neo4j Integration**: Stores and queries the knowledge graph +- **Hybrid Search**: Combines semantic similarity and BM25 text retrieval +- **Graph-based Reranking**: Reorders results based on graph distance + +## Prerequisites + +### 1. Neo4j Database + +Install and start Neo4j: +- Download [Neo4j Desktop](https://neo4j.com/download/) +- Create a new database with username `neo4j` and password `password` +- Start the database + +### 2. AWS Account and Bedrock Access + +You need: +- An AWS account with Bedrock access +- Model access enabled for the models you want to use: + - Claude models (e.g., `us.anthropic.claude-sonnet-4-20250514-v1:0`) + - Titan embedding models (e.g., `amazon.titan-embed-text-v2:0`) + - Rerank models (e.g., `cohere.rerank-v3-5:0`) + +### 3. AWS Credentials + +Configure AWS credentials using one of these methods: + +**Option 1: AWS CLI** +```bash +aws configure +``` + +**Option 2: Environment Variables** +```bash +export AWS_ACCESS_KEY_ID=your_access_key +export AWS_SECRET_ACCESS_KEY=your_secret_key +export AWS_REGION=us-east-1 +``` + +### 4. Python Dependencies + +Sync dependencies with Amazon Bedrock support: +```bash +uv sync --extra bedrock +``` + +## Setup + +1. **Navigate to the example directory:** + ```bash + cd examples/amazon-bedrock + ``` + +2. **Copy and configure environment variables:** + ```bash + cp .env.example .env + # Edit .env with your configuration + ``` + + +## Configuration + +### Environment Variables + +The example uses these environment variables (see `.env.example`): + +| Variable | Default | Description | +|----------|---------|-------------| +| `NEO4J_URI` | `bolt://localhost:7687` | Neo4j connection URI | +| `NEO4J_USER` | `neo4j` | Neo4j username | +| `NEO4J_PASSWORD` | `password` | Neo4j password | +| `AWS_REGION` | `us-east-1` | AWS region for Bedrock | +| `BEDROCK_LLM_MODEL` | `us.anthropic.claude-sonnet-4-20250514-v1:0` | Claude model for LLM | +| `BEDROCK_EMBEDDING_MODEL` | `amazon.titan-embed-text-v2:0` | Titan model for embeddings | +| `BEDROCK_RERANKER_MODEL` | `cohere.rerank-v3-5:0` | Rerank model | + +### Model Availability by Region + +Different Bedrock models are available in different AWS regions. Please check the [AWS Bedrock documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html) for the latest information on model availability by region. + +## Running the Example + +```bash +uv run python amazon_bedrock_neo4j.py +``` + +## What the Example Does + +1. **Initializes Graphiti** with Amazon Bedrock clients and Neo4j +2. **Adds Episodes** containing information about California politics +3. **Performs Hybrid Search** using semantic similarity and BM25 +4. **Demonstrates Reranking** using Amazon Bedrock reranker +5. **Shows Integration Details** including model and region information + +## Expected Output + +``` +Added episode: California Politics 0 (text) +Added episode: California Politics 1 (text) +Added episode: California Politics 2 (json) + +Searching for: 'Who was the California Attorney General?' + +Search Results: +UUID: [uuid] +Fact: Kamala Harris holds the position of Attorney General of California +Valid from: [timestamp] +--- + +... + +Reranking search results based on graph distance: +Using center node UUID: [uuid] +Reranker model: cohere.rerank-v3-5:0 + +Reranked Search Results (using Amazon Bedrock reranker): +UUID: [uuid] +Fact: Kamala Harris is the Attorney General of California +Valid from: [timestamp] +--- + +... + +Connection closed +``` + +## Troubleshooting + +### Common Issues + +**1. AWS Credentials Not Found** +``` +NoCredentialsError: Unable to locate credentials +``` +- Ensure AWS credentials are properly configured +- Check that your AWS profile has the correct permissions + +**2. Model Access Denied** +``` +AccessDeniedException: User is not authorized to perform: bedrock:InvokeModel +``` +- Request access to the specific models in the AWS Bedrock console +- Ensure your AWS account has Bedrock permissions + +**3. Model Not Available in Region** +``` +ValidationException: The model ID is not supported in this region +``` +- Check model availability in your selected region +- Update the region or model in your configuration + +**4. Neo4j Connection Failed** +``` +ServiceUnavailable: Failed to establish connection +``` +- Ensure Neo4j is running +- Check the connection URI, username, and password + +### Getting Help + +- Check the [Graphiti documentation](https://help.getzep.com/graphiti) +- Join the [Zep Discord server](https://discord.com/invite/W8Kw6bsgXQ) #graphiti channel +- Review AWS Bedrock documentation for model access and permissions + +## Next Steps + +- Explore different Amazon Bedrock models +- Try different AWS regions +- Experiment with custom entity types +- Integrate with your own data sources +- Scale up with larger datasets \ No newline at end of file diff --git a/examples/amazon-bedrock/amazon_bedrock_neo4j.py b/examples/amazon-bedrock/amazon_bedrock_neo4j.py new file mode 100644 index 000000000..c06ecf643 --- /dev/null +++ b/examples/amazon-bedrock/amazon_bedrock_neo4j.py @@ -0,0 +1,232 @@ +""" +Copyright 2025, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import json +import logging +import os +from datetime import datetime, timezone +from logging import INFO + +from dotenv import load_dotenv + +from graphiti_core import Graphiti +from graphiti_core.cross_encoder.amazon_bedrock_reranker_client import AmazonBedrockRerankerClient +from graphiti_core.embedder.amazon_bedrock import AmazonBedrockEmbedder, AmazonBedrockEmbedderConfig +from graphiti_core.llm_client.amazon_bedrock_client import AmazonBedrockLLMClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.nodes import EpisodeType + +################################################# +# CONFIGURATION +################################################# +# Set up logging and environment variables for +# connecting to Neo4j database and Amazon Bedrock +################################################# + +# Configure logging +logging.basicConfig( + level=INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', +) +logger = logging.getLogger(__name__) + +load_dotenv() + +# Neo4j connection parameters +# Make sure Neo4j Desktop is running with a local DBMS started +neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687') +neo4j_user = os.environ.get('NEO4J_USER', 'neo4j') +neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password') + +# AWS Bedrock configuration +aws_region = os.environ.get('AWS_REGION', 'us-east-1') +llm_model = os.environ.get('BEDROCK_LLM_MODEL', 'us.anthropic.claude-sonnet-4-20250514-v1:0') +embedding_model = os.environ.get('BEDROCK_EMBEDDING_MODEL', 'amazon.titan-embed-text-v2:0') +reranker_model = os.environ.get('BEDROCK_RERANKER_MODEL', 'cohere.rerank-v3-5:0') + + +async def main(): + ################################################# + # INITIALIZATION + ################################################# + # Connect to Neo4j and Amazon Bedrock, then set up + # Graphiti indices. This is required before using + # other Graphiti functionality + ################################################# + + # Create Amazon Bedrock clients + llm_client = AmazonBedrockLLMClient( + config=LLMConfig( + model=llm_model, + temperature=0.1, + max_tokens=1000, + ), + region=aws_region, + ) + + embedder_client = AmazonBedrockEmbedder( + config=AmazonBedrockEmbedderConfig( + model=embedding_model, + region=aws_region, + ) + ) + + reranker_client = AmazonBedrockRerankerClient( + model=reranker_model, + region=aws_region, + max_results=10, + ) + + # Initialize Graphiti with Neo4j connection and Amazon Bedrock clients + graphiti = Graphiti( + neo4j_uri, + neo4j_user, + neo4j_password, + llm_client=llm_client, + embedder=embedder_client, + cross_encoder=reranker_client, + ) + + try: + ################################################# + # ADDING EPISODES + ################################################# + # Episodes are the primary units of information + # in Graphiti. They can be text or structured JSON + # and are automatically processed to extract entities + # and relationships. + ################################################# + + # Example: Add Episodes + # Episodes list containing both text and JSON episodes + episodes = [ + { + 'content': 'Kamala Harris is the Attorney General of California. She was previously ' + 'the district attorney for San Francisco.', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017', + 'type': EpisodeType.text, + 'description': 'podcast transcript', + }, + { + 'content': { + 'name': 'Gavin Newsom', + 'position': 'Governor', + 'state': 'California', + 'previous_role': 'Lieutenant Governor', + 'previous_location': 'San Francisco', + }, + 'type': EpisodeType.json, + 'description': 'podcast metadata', + }, + ] + + # Add episodes to the graph + for i, episode in enumerate(episodes): + await graphiti.add_episode( + name=f'California Politics {i}', + episode_body=( + episode['content'] + if isinstance(episode['content'], str) + else json.dumps(episode['content']) + ), + source=episode['type'], + source_description=episode['description'], + reference_time=datetime.now(timezone.utc), + ) + print(f'Added episode: California Politics {i} ({episode["type"].value})') + + ################################################# + # BASIC SEARCH + ################################################# + # The simplest way to retrieve relationships (edges) + # from Graphiti is using the search method, which + # performs a hybrid search combining semantic + # similarity and BM25 text retrieval. + ################################################# + + # Perform a hybrid search combining semantic similarity and BM25 retrieval + print("\nSearching for: 'Who was the California Attorney General?'") + results = await graphiti.search('Who was the California Attorney General?') + + # Print search results + print('\nSearch Results:') + for result in results: + print(f'UUID: {result.uuid}') + print(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + print(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + print(f'Valid until: {result.invalid_at}') + print('---') + + ################################################# + # CENTER NODE SEARCH WITH RERANKING + ################################################# + # For more contextually relevant results, you can + # use a center node to rerank search results based + # on their graph distance to a specific node. + # This example demonstrates the Amazon Bedrock + # reranker in action. + ################################################# + + # Use the top search result's UUID as the center node for reranking + if results and len(results) > 0: + # Get the source node UUID from the top result + center_node_uuid = results[0].source_node_uuid + + print('\nReranking search results based on graph distance:') + print(f'Using center node UUID: {center_node_uuid}') + print(f'Reranker model: {reranker_model}') + + reranked_results = await graphiti.search( + 'Who was the California Attorney General?', + center_node_uuid=center_node_uuid, + ) + + # Print reranked search results + print('\nReranked Search Results (using Amazon Bedrock reranker):') + for result in reranked_results: + print(f'UUID: {result.uuid}') + print(f'Fact: {result.fact}') + if hasattr(result, 'valid_at') and result.valid_at: + print(f'Valid from: {result.valid_at}') + if hasattr(result, 'invalid_at') and result.invalid_at: + print(f'Valid until: {result.invalid_at}') + print('---') + else: + print('No results found in the initial search to use as center node.') + + finally: + ################################################# + # CLEANUP + ################################################# + # Always close the connection to Neo4j when + # finished to properly release resources + ################################################# + + # Close the connection + await graphiti.close() + print('\nConnection closed') + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/graphiti_core/cross_encoder/amazon_bedrock_reranker_client.py b/graphiti_core/cross_encoder/amazon_bedrock_reranker_client.py new file mode 100644 index 000000000..8d8aa901e --- /dev/null +++ b/graphiti_core/cross_encoder/amazon_bedrock_reranker_client.py @@ -0,0 +1,123 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import logging +from typing import TYPE_CHECKING, Literal + +from .client import CrossEncoderClient + +if TYPE_CHECKING: + import boto3 +else: + try: + import boto3 + except ImportError: + raise ImportError( + 'boto3 is required for AmazonBedrockRerankerClient. ' + 'Install it with: pip install graphiti-core[bedrock]' + ) from None + +logger = logging.getLogger(__name__) + +BedrockRerankModel = Literal['cohere.rerank-v3-5:0', 'amazon.rerank-v1:0'] + +DEFAULT_MODEL: BedrockRerankModel = 'cohere.rerank-v3-5:0' + +# Model support by region as of December 2025 +MODEL_REGIONS = { + 'amazon.rerank-v1:0': ['ap-northeast-1', 'ca-central-1', 'eu-central-1', 'us-west-2'], + 'cohere.rerank-v3-5:0': [ + 'ap-northeast-1', + 'ca-central-1', + 'eu-central-1', + 'us-east-1', + 'us-west-2', + ], +} + + +class AmazonBedrockRerankerClient(CrossEncoderClient): + def __init__( + self, + model: BedrockRerankModel = DEFAULT_MODEL, + region: str = 'us-east-1', + max_results: int = 100, + ): + # Validate region supports the model + if region not in MODEL_REGIONS[model]: + supported_regions = ', '.join(MODEL_REGIONS[model]) + raise ValueError( + f'Model {model} is not supported in region {region}. Supported regions: {supported_regions}' + ) + + self.model = model + self.region = region + self.max_results = max_results + self.client = boto3.client('bedrock-agent-runtime', region_name=region) + + async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: + if not passages: + return [] + + sources = [ + { + 'type': 'INLINE', + 'inlineDocumentSource': { + 'type': 'TEXT', + 'textDocument': { + 'text': passage, + }, + }, + } + for passage in passages + ] + + model_arn = f'arn:aws:bedrock:{self.region}::foundation-model/{self.model}' + + try: + # Use executor to run sync boto3 call in async context + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: self.client.rerank( + queries=[{'type': 'TEXT', 'textQuery': {'text': query}}], + sources=sources, + rerankingConfiguration={ + 'type': 'BEDROCK_RERANKING_MODEL', + 'bedrockRerankingConfiguration': { + 'numberOfResults': min(self.max_results, len(passages)), + 'modelConfiguration': { + 'modelArn': model_arn, + }, + }, + }, + ), + ) + + # Extract results and map back to original passages + results = [] + for result in response.get('results', []): + index = result['index'] + relevance_score = result['relevanceScore'] + passage = passages[index] + results.append((passage, relevance_score)) + + return results + + except Exception as e: + logger.error(f'Error in Bedrock reranking: {e}') + raise diff --git a/graphiti_core/embedder/amazon_bedrock.py b/graphiti_core/embedder/amazon_bedrock.py new file mode 100644 index 000000000..678435069 --- /dev/null +++ b/graphiti_core/embedder/amazon_bedrock.py @@ -0,0 +1,82 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json +import logging +from collections.abc import Iterable +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import boto3 +else: + try: + import boto3 + except ImportError: + raise ImportError( + 'boto3 is required for AmazonBedrockEmbedder. ' + 'Install it with: pip install graphiti-core[bedrock]' + ) from None + +from .client import EmbedderClient, EmbedderConfig + +logger = logging.getLogger(__name__) + +DEFAULT_EMBEDDING_MODEL = 'amazon.titan-embed-text-v2:0' + + +class AmazonBedrockEmbedderConfig(EmbedderConfig): + model: str = DEFAULT_EMBEDDING_MODEL + region: str = 'us-east-1' + + +class AmazonBedrockEmbedder(EmbedderClient): + def __init__(self, config: AmazonBedrockEmbedderConfig | None = None): + self.config = config or AmazonBedrockEmbedderConfig() + self.client = boto3.client('bedrock-runtime', region_name=self.config.region) + + async def create( + self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]] + ) -> list[float]: + if isinstance(input_data, str): + text = input_data + elif isinstance(input_data, list): + text = ' '.join(str(item) for item in input_data) + else: + text = str(input_data) + + body = json.dumps({'inputText': text}) + + try: + response = self.client.invoke_model( + modelId=self.config.model, + body=body, + accept='application/json', + contentType='application/json', + ) + + response_body = json.loads(response['body'].read().decode('utf-8')) + return response_body['embedding'] + + except Exception as e: + logger.error(f'Bedrock embedding failed: {e}') + raise + + async def create_batch(self, input_data_list: list[str]) -> list[list[float]]: + embeddings = [] + for text in input_data_list: + embedding = await self.create(text) + embeddings.append(embedding) + return embeddings diff --git a/graphiti_core/llm_client/amazon_bedrock_client.py b/graphiti_core/llm_client/amazon_bedrock_client.py new file mode 100644 index 000000000..9b857e83a --- /dev/null +++ b/graphiti_core/llm_client/amazon_bedrock_client.py @@ -0,0 +1,206 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import json +import logging +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from ..prompts.models import Message +from .client import LLMClient +from .config import DEFAULT_MAX_TOKENS, LLMConfig +from .errors import RateLimitError + +if TYPE_CHECKING: + import boto3 +else: + try: + import boto3 + except ImportError: + raise ImportError( + 'boto3 is required for AmazonBedrockLLMClient. ' + 'Install it with: pip install graphiti-core[bedrock]' + ) from None + +logger = logging.getLogger(__name__) + + +class AmazonBedrockLLMClient(LLMClient): + def __init__( + self, + config: LLMConfig | None = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + region: str = 'us-east-1', + ): + super().__init__(config, cache=False) + + self.region = region + self.model = ( + config.model + if config and config.model + else 'us.anthropic.claude-sonnet-4-20250514-v1:0' + ) + self.temperature = config.temperature if config and config.temperature is not None else 0.7 + self.max_tokens = max_tokens + + self.client = boto3.client('bedrock-runtime', region_name=self.region) + + async def _generate_response( + self, + messages: list[Message], + response_model: type[BaseModel] | None = None, + max_tokens: int = DEFAULT_MAX_TOKENS, + model_size=None, + ): + # Convert Message objects to dict format + message_dicts = [msg.model_dump() for msg in messages] + + if response_model: + # Add JSON schema instruction for structured output + schema = response_model.model_json_schema() + schema_str = json.dumps(schema, indent=2) + + # Modify the last user message to include schema instruction + if message_dicts and message_dicts[-1]['role'] == 'user': + message_dicts[-1]['content'] += ( + f'\n\nPlease respond with valid JSON that follows this exact schema:\n{schema_str}\n\nIMPORTANT: Return ONLY valid JSON with no extra text, explanations, or markdown formatting.' + ) + + text_response = await self._invoke_bedrock_model( + model=self.model, + messages=message_dicts, + temperature=self.temperature, + max_tokens=max_tokens, + response_format='json', + ) + + try: + parsed_model = response_model.model_validate_json(text_response) + return parsed_model.model_dump() + except Exception as e: + logger.error(f'Failed to parse structured Bedrock response: {e}') + logger.error(f'Raw response: {text_response}') + raise + else: + text_response = await self._invoke_bedrock_model( + model=self.model, + messages=message_dicts, + temperature=self.temperature, + max_tokens=max_tokens, + response_format='text', + ) + return {'content': text_response} + + async def _invoke_bedrock_model( + self, + model: str, + messages: list[dict], + temperature: float, + max_tokens: int, + response_format: str, + ) -> str: + # Separate system prompt and user messages + system_prompt = None + final_messages = [m for m in messages if m['role'] != 'system'] + + for m in messages: + if m['role'] == 'system': + system_prompt = m['content'] + break + + body_dict = { + 'messages': final_messages, + 'temperature': temperature, + 'max_tokens': max_tokens, + 'anthropic_version': 'bedrock-2023-05-31', + } + + if system_prompt: + body_dict['system'] = system_prompt + + body = json.dumps(body_dict) + + try: + # Use executor to run sync boto3 call in async context + loop = asyncio.get_event_loop() + resp = await loop.run_in_executor( + None, + lambda: self.client.invoke_model( + modelId=model, + body=body, + accept='application/json', + contentType='application/json', + ), + ) + + data = json.loads(resp['body'].read().decode('utf-8')) + + if 'content' in data and data['content']: + text = data['content'][0].get('text', '') + elif 'outputText' in data: + text = data['outputText'] + else: + text = json.dumps(data) + + # Clean JSON response format + if response_format == 'json': + text = self._clean_json_response(text) + return text.strip() + + except Exception as e: + if 'throttling' in str(e).lower() or 'rate' in str(e).lower(): + raise RateLimitError(f'Rate limit exceeded: {e}') from e + logger.error(f'Bedrock model invocation failed: {e}') + raise + + def _clean_json_response(self, text: str) -> str: + """Clean JSON response from markdown formatting and extract JSON.""" + import re + + text = text.strip() + + logger.debug(f'Raw Bedrock response: {text[:500]}...') + + # Remove code blocks + text = re.sub(r'^```(?:json|JSON)?\s*', '', text, flags=re.MULTILINE) + text = re.sub(r'```\s*$', '', text, flags=re.MULTILINE) + + # Fix double braces issue + text = re.sub(r'^\{\{', '{', text) + text = re.sub(r'\}\}$', '}', text) + + # Find JSON object - look for first { and last } + start_idx = text.find('{') + if start_idx != -1: + # Find the matching closing brace + brace_count = 0 + end_idx = -1 + for i in range(start_idx, len(text)): + if text[i] == '{': + brace_count += 1 + elif text[i] == '}': + brace_count -= 1 + if brace_count == 0: + end_idx = i + 1 + break + + if end_idx != -1: + text = text[start_idx:end_idx] + + logger.debug(f'Cleaned JSON response: {text[:200]}...') + return text.strip() diff --git a/pyproject.toml b/pyproject.toml index d5ae9fb9a..ccf734038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ voyageai = ["voyageai>=0.2.3"] neo4j-opensearch = ["boto3>=1.39.16", "opensearch-py>=3.0.0"] sentence-transformers = ["sentence-transformers>=3.2.1"] neptune = ["langchain-aws>=0.2.29", "opensearch-py>=3.0.0", "boto3>=1.39.16"] +bedrock = ["boto3>=1.39.16"] tracing = ["opentelemetry-api>=1.20.0", "opentelemetry-sdk>=1.20.0"] dev = [ "pyright>=1.1.404", diff --git a/tests/cross_encoder/test_amazon_bedrock_reranker_client.py b/tests/cross_encoder/test_amazon_bedrock_reranker_client.py new file mode 100644 index 000000000..1b71bc025 --- /dev/null +++ b/tests/cross_encoder/test_amazon_bedrock_reranker_client.py @@ -0,0 +1,284 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/cross_encoder/test_amazon_bedrock_reranker_client.py + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from graphiti_core.cross_encoder.amazon_bedrock_reranker_client import ( + DEFAULT_MODEL, + MODEL_REGIONS, + AmazonBedrockRerankerClient, +) + + +@pytest.fixture +def mock_boto3_client(): + """Fixture to mock the boto3 bedrock-agent-runtime client.""" + with patch('boto3.client') as mock_client: + mock_instance = MagicMock() + mock_client.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def mock_rerank_response(): + """Create a mock Bedrock rerank response.""" + return { + 'results': [ + {'index': 0, 'relevanceScore': 0.85}, + {'index': 1, 'relevanceScore': 0.65}, + {'index': 2, 'relevanceScore': 0.45}, + ] + } + + +@pytest.fixture +def reranker_client(mock_boto3_client): + """Fixture to create an AmazonBedrockRerankerClient with a mocked client.""" + client = AmazonBedrockRerankerClient(model=DEFAULT_MODEL, region='us-east-1', max_results=10) + client.client = mock_boto3_client + return client + + +class TestAmazonBedrockRerankerClientInitialization: + """Tests for AmazonBedrockRerankerClient initialization.""" + + def test_init_with_valid_model_region_combination(self): + """Test initialization with valid model and region combination.""" + client = AmazonBedrockRerankerClient( + model='cohere.rerank-v3-5:0', region='us-east-1', max_results=50 + ) + + assert client.model == 'cohere.rerank-v3-5:0' + assert client.region == 'us-east-1' + assert client.max_results == 50 + + def test_init_with_defaults(self): + """Test initialization with default values.""" + client = AmazonBedrockRerankerClient() + + assert client.model == DEFAULT_MODEL + assert client.region == 'us-east-1' + assert client.max_results == 100 + + def test_init_with_amazon_model_valid_region(self): + """Test initialization with Amazon model in valid region.""" + client = AmazonBedrockRerankerClient(model='amazon.rerank-v1:0', region='us-west-2') + + assert client.model == 'amazon.rerank-v1:0' + assert client.region == 'us-west-2' + + def test_init_with_invalid_model_region_combination(self): + """Test initialization with invalid model and region combination.""" + with pytest.raises(ValueError) as exc_info: + AmazonBedrockRerankerClient( + model='amazon.rerank-v1:0', + region='us-east-1', # Amazon model not available in us-east-1 + ) + + assert 'Model amazon.rerank-v1:0 is not supported in region us-east-1' in str( + exc_info.value + ) + assert 'us-west-2' in str(exc_info.value) # Should list supported regions + + def test_init_with_cohere_model_invalid_region(self): + """Test initialization with Cohere model in unsupported region.""" + with pytest.raises(ValueError) as exc_info: + AmazonBedrockRerankerClient( + model='cohere.rerank-v3-5:0', + region='ap-south-1', # Not in supported regions + ) + + assert 'Model cohere.rerank-v3-5:0 is not supported in region ap-south-1' in str( + exc_info.value + ) + + def test_model_regions_mapping(self): + """Test that MODEL_REGIONS mapping is correct.""" + # Verify Amazon model regions + amazon_regions = MODEL_REGIONS['amazon.rerank-v1:0'] + expected_amazon_regions = ['ap-northeast-1', 'ca-central-1', 'eu-central-1', 'us-west-2'] + assert amazon_regions == expected_amazon_regions + + # Verify Cohere model regions + cohere_regions = MODEL_REGIONS['cohere.rerank-v3-5:0'] + expected_cohere_regions = [ + 'ap-northeast-1', + 'ca-central-1', + 'eu-central-1', + 'us-east-1', + 'us-west-2', + ] + assert cohere_regions == expected_cohere_regions + + +class TestAmazonBedrockRerankerClientRanking: + """Tests for AmazonBedrockRerankerClient rank method.""" + + @pytest.mark.asyncio + async def test_rank_basic_functionality( + self, reranker_client, mock_boto3_client, mock_rerank_response + ): + """Test basic ranking functionality.""" + # Setup + query = 'What is machine learning?' + passages = [ + 'Machine learning is a subset of artificial intelligence.', + 'Python is a programming language.', + 'Deep learning uses neural networks for pattern recognition.', + ] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = mock_rerank_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await reranker_client.rank(query, passages) + + # Assertions + assert len(result) == 3 + assert all(isinstance(item, tuple) for item in result) + assert all( + isinstance(passage, str) and isinstance(score, float) for passage, score in result + ) + + # Check that results are sorted by relevance score (descending) + scores = [score for _, score in result] + assert scores == sorted(scores, reverse=True) + + # Verify specific results based on mock response + assert result[0][1] == 0.85 # Highest score + assert result[1][1] == 0.65 # Medium score + assert result[2][1] == 0.45 # Lowest score + + # Verify executor was called + mock_executor.assert_called_once() + + @pytest.mark.asyncio + async def test_rank_empty_passages(self, reranker_client): + """Test ranking with empty passages list.""" + query = 'Test query' + passages = [] + + result = await reranker_client.rank(query, passages) + + assert result == [] + + @pytest.mark.asyncio + async def test_rank_single_passage(self, reranker_client, mock_boto3_client): + """Test ranking with a single passage.""" + # Setup single result response + single_response = {'results': [{'index': 0, 'relevanceScore': 0.75}]} + + query = 'Test query' + passages = ['Single test passage'] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = single_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await reranker_client.rank(query, passages) + + assert len(result) == 1 + assert result[0][0] == 'Single test passage' + assert result[0][1] == 0.75 + + @pytest.mark.asyncio + async def test_rank_api_error_handling(self, reranker_client, mock_boto3_client): + """Test handling of API errors.""" + query = 'Test query' + passages = ['Passage 1', 'Passage 2'] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.side_effect = Exception('Bedrock API error') + mock_loop.return_value.run_in_executor = mock_executor + + with pytest.raises(Exception) as exc_info: + await reranker_client.rank(query, passages) + + assert 'Bedrock API error' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_rank_empty_results_response(self, reranker_client, mock_boto3_client): + """Test handling of empty results in response.""" + empty_response = {'results': []} + + query = 'Test query' + passages = ['Test passage'] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = empty_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await reranker_client.rank(query, passages) + + assert result == [] + + @pytest.mark.asyncio + async def test_rank_missing_results_key(self, reranker_client, mock_boto3_client): + """Test handling of response missing results key.""" + invalid_response = {'other_key': 'value'} + + query = 'Test query' + passages = ['Test passage'] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = invalid_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await reranker_client.rank(query, passages) + + # Should handle gracefully and return empty list + assert result == [] + + +class TestAmazonBedrockRerankerClientConfiguration: + """Tests for AmazonBedrockRerankerClient configuration.""" + + def test_different_models_and_regions(self): + """Test initialization with different valid model/region combinations.""" + # Test Cohere model in different regions + for region in MODEL_REGIONS['cohere.rerank-v3-5:0']: + client = AmazonBedrockRerankerClient(model='cohere.rerank-v3-5:0', region=region) + assert client.model == 'cohere.rerank-v3-5:0' + assert client.region == region + + # Test Amazon model in different regions + for region in MODEL_REGIONS['amazon.rerank-v1:0']: + client = AmazonBedrockRerankerClient(model='amazon.rerank-v1:0', region=region) + assert client.model == 'amazon.rerank-v1:0' + assert client.region == region + + def test_max_results_configuration(self): + """Test different max_results configurations.""" + # Test default + client = AmazonBedrockRerankerClient() + assert client.max_results == 100 + + # Test custom value + client = AmazonBedrockRerankerClient(max_results=50) + assert client.max_results == 50 + + +if __name__ == '__main__': + pytest.main(['-v', 'test_amazon_bedrock_reranker_client.py']) diff --git a/tests/cross_encoder/test_amazon_bedrock_reranker_client_int.py b/tests/cross_encoder/test_amazon_bedrock_reranker_client_int.py new file mode 100644 index 000000000..7f9be8d61 --- /dev/null +++ b/tests/cross_encoder/test_amazon_bedrock_reranker_client_int.py @@ -0,0 +1,222 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/cross_encoder/test_amazon_bedrock_reranker_client_int.py +# Requires: AWS credentials configured and Bedrock model access + +import os +from pathlib import Path + +import pytest + +from graphiti_core.cross_encoder.amazon_bedrock_reranker_client import ( + AmazonBedrockRerankerClient, +) + + +def _has_aws_credentials(): + """Check if AWS credentials are available.""" + # Check environment variables + if os.getenv('AWS_ACCESS_KEY_ID') or os.getenv('AWS_PROFILE'): + return True + + # Check for default credentials file + credentials_file = Path.home() / '.aws' / 'credentials' + return credentials_file.exists() + + +# Skip all tests if AWS credentials not available +pytestmark = pytest.mark.skipif(not _has_aws_credentials(), reason='AWS credentials not configured') + + +@pytest.fixture +def cohere_reranker(): + """Create a real AmazonBedrockRerankerClient with Cohere model.""" + return AmazonBedrockRerankerClient( + model='cohere.rerank-v3-5:0', region='us-east-1', max_results=10 + ) + + +@pytest.fixture +def amazon_reranker(): + """Create a real AmazonBedrockRerankerClient with Amazon model.""" + return AmazonBedrockRerankerClient( + model='amazon.rerank-v1:0', + region='us-west-2', # Amazon model available in us-west-2 + max_results=10, + ) + + +class TestAmazonBedrockRerankerClientIntegration: + """Integration tests for AmazonBedrockRerankerClient with real AWS Bedrock.""" + + @pytest.mark.asyncio + async def test_cohere_rerank_basic_functionality(self, cohere_reranker): + """Test basic reranking functionality with Cohere model.""" + query = 'What is machine learning?' + passages = [ + 'Machine learning is a subset of artificial intelligence that focuses on algorithms.', + 'Python is a popular programming language used for web development.', + 'Deep learning uses neural networks to learn patterns in data.', + 'JavaScript is commonly used for frontend web development.', + 'Supervised learning requires labeled training data.', + ] + + result = await cohere_reranker.rank(query, passages) + + # Verify basic structure + assert isinstance(result, list) + assert len(result) <= len(passages) + assert all(isinstance(item, tuple) for item in result) + assert all(len(item) == 2 for item in result) + assert all( + isinstance(passage, str) and isinstance(score, float) for passage, score in result + ) + + # Verify scores are in valid range + scores = [score for _, score in result] + assert all(0.0 <= score <= 1.0 for score in scores) + + # Verify results are sorted by relevance (descending) + assert scores == sorted(scores, reverse=True) + + # ML-related passages should rank higher + top_passage = result[0][0] + assert any( + word in top_passage.lower() for word in ['machine', 'learning', 'neural', 'supervised'] + ) + + @pytest.mark.asyncio + async def test_amazon_rerank_basic_functionality(self, amazon_reranker): + """Test basic reranking functionality with Amazon model.""" + query = 'Python programming' + passages = [ + 'Python is a high-level programming language known for its simplicity.', + 'The snake is a reptile that moves by slithering.', + 'Programming languages help developers create software applications.', + 'Reptiles are cold-blooded animals that lay eggs.', + ] + + result = await amazon_reranker.rank(query, passages) + + # Verify basic structure + assert isinstance(result, list) + assert len(result) <= len(passages) + assert all(isinstance(item, tuple) for item in result) + + # Programming-related passages should rank higher + top_passage = result[0][0] + assert any(word in top_passage.lower() for word in ['python', 'programming', 'language']) + + @pytest.mark.asyncio + async def test_rerank_empty_passages(self, cohere_reranker): + """Test reranking with empty passages list.""" + query = 'Test query' + passages = [] + + result = await cohere_reranker.rank(query, passages) + assert result == [] + + @pytest.mark.asyncio + async def test_rerank_single_passage(self, cohere_reranker): + """Test reranking with single passage.""" + query = 'artificial intelligence' + passages = ['AI is transforming various industries with automation.'] + + result = await cohere_reranker.rank(query, passages) + + assert len(result) == 1 + assert result[0][0] == passages[0] + assert isinstance(result[0][1], float) + assert 0.0 <= result[0][1] <= 1.0 + + @pytest.mark.asyncio + async def test_rerank_relevance_ordering(self, cohere_reranker): + """Test that more relevant passages get higher scores.""" + query = 'climate change effects' + passages = [ + 'Global warming is causing ice caps to melt rapidly.', # Highly relevant + 'Climate change leads to extreme weather patterns.', # Highly relevant + 'The recipe for chocolate cake requires flour and eggs.', # Not relevant + 'Rising sea levels threaten coastal communities.', # Relevant + 'My favorite color is blue and I like painting.', # Not relevant + ] + + result = await cohere_reranker.rank(query, passages) + + # Top results should be climate-related + top_3_passages = [passage for passage, _ in result[:3]] + climate_passages = [ + p + for p in top_3_passages + if any(word in p.lower() for word in ['climate', 'warming', 'sea', 'weather', 'ice']) + ] + + # At least 2 of top 3 should be climate-related + assert len(climate_passages) >= 2 + + @pytest.mark.asyncio + async def test_max_results_parameter(self, cohere_reranker): + """Test that max_results parameter is respected.""" + # Set max_results to 3 + reranker = AmazonBedrockRerankerClient( + model='cohere.rerank-v3-5:0', region='us-east-1', max_results=3 + ) + + query = 'technology' + passages = [ + 'Artificial intelligence is advancing rapidly.', + 'Smartphones have changed communication.', + 'Electric cars are becoming more popular.', + 'Social media connects people globally.', + 'Renewable energy is the future.', + ] + + result = await reranker.rank(query, passages) + + # Should return at most 3 results + assert len(result) <= 3 + + @pytest.mark.asyncio + async def test_different_query_types(self, cohere_reranker): + """Test reranking with different types of queries.""" + passages = [ + 'The capital of France is Paris.', + 'Machine learning algorithms process data.', + 'Cooking pasta requires boiling water.', + 'Paris is known for the Eiffel Tower.', + ] + + # Factual query + factual_result = await cohere_reranker.rank('What is the capital of France?', passages) + + # Technical query + technical_result = await cohere_reranker.rank('machine learning data processing', passages) + + # Both should return valid results with different rankings + assert len(factual_result) > 0 + assert len(technical_result) > 0 + + # Top results should be different for different queries + factual_top = factual_result[0][0] + technical_top = technical_result[0][0] + + assert 'France' in factual_top or 'Paris' in factual_top + assert 'machine' in technical_top.lower() or 'learning' in technical_top.lower() + + +if __name__ == '__main__': + pytest.main(['-v', 'test_amazon_bedrock_reranker_client_int.py']) diff --git a/tests/embedder/test_amazon_bedrock.py b/tests/embedder/test_amazon_bedrock.py new file mode 100644 index 000000000..7ff4b1270 --- /dev/null +++ b/tests/embedder/test_amazon_bedrock.py @@ -0,0 +1,243 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/embedder/test_amazon_bedrock.py + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from graphiti_core.embedder.amazon_bedrock import ( + DEFAULT_EMBEDDING_MODEL, + AmazonBedrockEmbedder, + AmazonBedrockEmbedderConfig, +) + + +@pytest.fixture +def mock_boto3_client(): + """Create a mocked boto3 bedrock-runtime client.""" + with patch('boto3.client') as mock_client: + mock_instance = MagicMock() + mock_client.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def mock_bedrock_response(): + """Create a mock Bedrock embeddings response.""" + mock_response = {'body': MagicMock()} + mock_response_body = {'embedding': [0.1] * 1024} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + return mock_response + + +@pytest.fixture +def bedrock_embedder(mock_boto3_client): + """Create an AmazonBedrockEmbedder with a mocked client.""" + config = AmazonBedrockEmbedderConfig(model=DEFAULT_EMBEDDING_MODEL, region='us-east-1') + embedder = AmazonBedrockEmbedder(config=config) + embedder.client = mock_boto3_client + return embedder + + +class TestAmazonBedrockEmbedderInitialization: + """Tests for AmazonBedrockEmbedder initialization.""" + + def test_init_with_config(self): + """Test initialization with a config object.""" + config = AmazonBedrockEmbedderConfig(model='custom.model:0', region='eu-west-1') + embedder = AmazonBedrockEmbedder(config=config) + + assert embedder.config.model == 'custom.model:0' + assert embedder.config.region == 'eu-west-1' + + def test_init_without_config(self): + """Test initialization without a config uses defaults.""" + embedder = AmazonBedrockEmbedder() + + assert embedder.config.model == DEFAULT_EMBEDDING_MODEL + assert embedder.config.region == 'us-east-1' + + def test_config_defaults(self): + """Test that config has correct default values.""" + config = AmazonBedrockEmbedderConfig() + + assert config.model == DEFAULT_EMBEDDING_MODEL + assert config.region == 'us-east-1' + + +class TestAmazonBedrockEmbedderCreate: + """Tests for AmazonBedrockEmbedder create method.""" + + @pytest.mark.asyncio + async def test_create_with_string_input( + self, bedrock_embedder, mock_boto3_client, mock_bedrock_response + ): + """Test create method with string input.""" + # Setup + mock_boto3_client.invoke_model.return_value = mock_bedrock_response + input_text = 'Test input string' + + # Call method + result = await bedrock_embedder.create(input_text) + + # Verify API call + mock_boto3_client.invoke_model.assert_called_once() + call_args = mock_boto3_client.invoke_model.call_args + + assert call_args[1]['modelId'] == DEFAULT_EMBEDDING_MODEL + assert call_args[1]['accept'] == 'application/json' + assert call_args[1]['contentType'] == 'application/json' + + # Verify request body + body = json.loads(call_args[1]['body']) + assert body['inputText'] == input_text + + # Verify result + assert isinstance(result, list) + assert len(result) == 1024 + assert all(isinstance(x, float) for x in result) + + @pytest.mark.asyncio + async def test_create_with_list_input( + self, bedrock_embedder, mock_boto3_client, mock_bedrock_response + ): + """Test create method with list of strings input.""" + # Setup + mock_boto3_client.invoke_model.return_value = mock_bedrock_response + input_list = ['First string', 'Second string', 'Third string'] + + # Call method + result = await bedrock_embedder.create(input_list) + + # Verify request body contains joined strings + call_args = mock_boto3_client.invoke_model.call_args + body = json.loads(call_args[1]['body']) + assert body['inputText'] == 'First string Second string Third string' + + # Verify result + assert isinstance(result, list) + assert len(result) == 1024 + + @pytest.mark.asyncio + async def test_create_with_other_input( + self, bedrock_embedder, mock_boto3_client, mock_bedrock_response + ): + """Test create method with non-string, non-list input.""" + # Setup + mock_boto3_client.invoke_model.return_value = mock_bedrock_response + input_data = 12345 # Integer input + + # Call method + result = await bedrock_embedder.create(input_data) + + # Verify request body contains string representation + call_args = mock_boto3_client.invoke_model.call_args + body = json.loads(call_args[1]['body']) + assert body['inputText'] == '12345' + + # Verify result + assert isinstance(result, list) + assert len(result) == 1024 + + @pytest.mark.asyncio + async def test_create_api_error(self, bedrock_embedder, mock_boto3_client): + """Test handling of API errors.""" + # Setup mock to raise exception + mock_boto3_client.invoke_model.side_effect = Exception('API Error') + + # Call method and verify exception is raised + with pytest.raises(Exception) as exc_info: + await bedrock_embedder.create('test input') + + assert 'API Error' in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_invalid_response_format(self, bedrock_embedder, mock_boto3_client): + """Test handling of invalid response format.""" + # Setup mock response with missing embedding field + mock_response = {'body': MagicMock()} + mock_response_body = {'invalid': 'response'} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + mock_boto3_client.invoke_model.return_value = mock_response + + # Call method and verify KeyError is raised + with pytest.raises(KeyError): + await bedrock_embedder.create('test input') + + +class TestAmazonBedrockEmbedderCreateBatch: + """Tests for AmazonBedrockEmbedder create_batch method.""" + + @pytest.mark.asyncio + async def test_create_batch_multiple_inputs( + self, bedrock_embedder, mock_boto3_client, mock_bedrock_response + ): + """Test create_batch method with multiple inputs.""" + # Setup + mock_boto3_client.invoke_model.return_value = mock_bedrock_response + input_batch = ['First text', 'Second text', 'Third text'] + + # Call method + result = await bedrock_embedder.create_batch(input_batch) + + # Verify API was called for each input + assert mock_boto3_client.invoke_model.call_count == 3 + + # Verify result structure + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(embedding, list) for embedding in result) + assert all(len(embedding) == 1024 for embedding in result) + + @pytest.mark.asyncio + async def test_create_batch_empty_list(self, bedrock_embedder, mock_boto3_client): + """Test create_batch method with empty input list.""" + # Call method + result = await bedrock_embedder.create_batch([]) + + # Verify no API calls were made + mock_boto3_client.invoke_model.assert_not_called() + + # Verify result + assert result == [] + + @pytest.mark.asyncio + async def test_create_batch_single_input( + self, bedrock_embedder, mock_boto3_client, mock_bedrock_response + ): + """Test create_batch method with single input.""" + # Setup + mock_boto3_client.invoke_model.return_value = mock_bedrock_response + input_batch = ['Single text'] + + # Call method + result = await bedrock_embedder.create_batch(input_batch) + + # Verify API was called once + assert mock_boto3_client.invoke_model.call_count == 1 + + # Verify result structure + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], list) + assert len(result[0]) == 1024 + + +if __name__ == '__main__': + pytest.main(['-xvs', __file__]) diff --git a/tests/embedder/test_amazon_bedrock_int.py b/tests/embedder/test_amazon_bedrock_int.py new file mode 100644 index 000000000..914b466c6 --- /dev/null +++ b/tests/embedder/test_amazon_bedrock_int.py @@ -0,0 +1,136 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/embedder/test_amazon_bedrock_int.py +# Requires: AWS credentials configured and Bedrock model access + +import os +from pathlib import Path + +import pytest + +from graphiti_core.embedder.amazon_bedrock import ( + DEFAULT_EMBEDDING_MODEL, + AmazonBedrockEmbedder, + AmazonBedrockEmbedderConfig, +) + + +def _has_aws_credentials(): + """Check if AWS credentials are available.""" + # Check environment variables + if os.getenv('AWS_ACCESS_KEY_ID') or os.getenv('AWS_PROFILE'): + return True + + # Check for default credentials file + credentials_file = Path.home() / '.aws' / 'credentials' + return credentials_file.exists() + + +# Skip all tests if AWS credentials not available +pytestmark = pytest.mark.skipif(not _has_aws_credentials(), reason='AWS credentials not configured') + + +@pytest.fixture +def bedrock_embedder(): + """Create a real AmazonBedrockEmbedder for integration testing.""" + config = AmazonBedrockEmbedderConfig(model=DEFAULT_EMBEDDING_MODEL, region='us-east-1') + return AmazonBedrockEmbedder(config=config) + + +class TestAmazonBedrockEmbedderIntegration: + """Integration tests for AmazonBedrockEmbedder with real AWS Bedrock.""" + + @pytest.mark.asyncio + async def test_create_single_embedding(self, bedrock_embedder): + """Test creating a single embedding.""" + text = 'This is a test sentence for embedding.' + + result = await bedrock_embedder.create(text) + + assert isinstance(result, list) + assert len(result) > 0 + assert all(isinstance(x, float) for x in result) + # Amazon Titan embeddings are typically 1024 or 1536 dimensions + assert len(result) in [1024, 1536] + + @pytest.mark.asyncio + async def test_create_batch_embeddings(self, bedrock_embedder): + """Test creating multiple embeddings in batch.""" + texts = ['First test sentence.', 'Second test sentence.', 'Third test sentence.'] + + result = await bedrock_embedder.create_batch(texts) + + assert isinstance(result, list) + assert len(result) == 3 + assert all(isinstance(embedding, list) for embedding in result) + assert all(len(embedding) > 0 for embedding in result) + assert all(isinstance(x, float) for embedding in result for x in embedding) + + # All embeddings should have the same dimension + dimensions = [len(embedding) for embedding in result] + assert all(dim == dimensions[0] for dim in dimensions) + + @pytest.mark.asyncio + async def test_embedding_similarity(self, bedrock_embedder): + """Test that similar texts have similar embeddings.""" + similar_texts = ['The cat sat on the mat.', 'A cat was sitting on a mat.'] + different_text = 'Quantum physics is fascinating.' + + # Get embeddings + similar_embeddings = await bedrock_embedder.create_batch(similar_texts) + different_embedding = await bedrock_embedder.create(different_text) + + # Calculate cosine similarity (simplified) + def cosine_similarity(a, b): + dot_product = sum(x * y for x, y in zip(a, b, strict=False)) + norm_a = sum(x * x for x in a) ** 0.5 + norm_b = sum(x * x for x in b) ** 0.5 + return dot_product / (norm_a * norm_b) + + # Similar texts should have higher similarity than different text + similar_similarity = cosine_similarity(similar_embeddings[0], similar_embeddings[1]) + different_similarity = cosine_similarity(similar_embeddings[0], different_embedding) + + assert similar_similarity > different_similarity + assert similar_similarity > 0.7 # Should be quite similar + + @pytest.mark.asyncio + async def test_different_input_types(self, bedrock_embedder): + """Test embedding different input types.""" + # String input + string_result = await bedrock_embedder.create('Test string') + + # List input + list_result = await bedrock_embedder.create(['Test', 'string', 'list']) + + # Both should return valid embeddings + assert isinstance(string_result, list) + assert isinstance(list_result, list) + assert len(string_result) > 0 + assert len(list_result) > 0 + assert all(isinstance(x, float) for x in string_result) + assert all(isinstance(x, float) for x in list_result) + + @pytest.mark.asyncio + async def test_empty_batch(self, bedrock_embedder): + """Test handling of empty batch.""" + result = await bedrock_embedder.create_batch([]) + assert result == [] + + +if __name__ == '__main__': + pytest.main(['-v', 'test_amazon_bedrock_int.py']) diff --git a/tests/llm_client/test_amazon_bedrock_client.py b/tests/llm_client/test_amazon_bedrock_client.py new file mode 100644 index 000000000..f6fe08168 --- /dev/null +++ b/tests/llm_client/test_amazon_bedrock_client.py @@ -0,0 +1,231 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/llm_client/test_amazon_bedrock_client.py + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel + +from graphiti_core.llm_client.amazon_bedrock_client import AmazonBedrockLLMClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.llm_client.errors import RateLimitError +from graphiti_core.prompts.models import Message + + +class ResponseModel(BaseModel): + """Test model for response testing.""" + + test_field: str + optional_field: int = 0 + + +@pytest.fixture +def mock_boto3_client(): + """Fixture to mock the boto3 bedrock-runtime client.""" + with patch('boto3.client') as mock_client: + mock_instance = MagicMock() + mock_client.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def bedrock_client(mock_boto3_client): + """Fixture to create an AmazonBedrockLLMClient with a mocked boto3 client.""" + config = LLMConfig( + model='us.anthropic.claude-sonnet-4-20250514-v1:0', temperature=0.5, max_tokens=1000 + ) + client = AmazonBedrockLLMClient(config=config, region='us-east-1') + client.client = mock_boto3_client + return client + + +class TestAmazonBedrockLLMClientInitialization: + """Tests for AmazonBedrockLLMClient initialization.""" + + def test_init_with_config(self): + """Test initialization with a config object.""" + config = LLMConfig(model='test-model', temperature=0.5, max_tokens=1000) + client = AmazonBedrockLLMClient(config=config, max_tokens=1000, region='eu-west-1') + + assert client.model == 'test-model' + assert client.temperature == 0.5 + assert client.max_tokens == 1000 + assert client.region == 'eu-west-1' + + def test_init_without_config(self): + """Test initialization without a config uses defaults.""" + client = AmazonBedrockLLMClient(region='ap-southeast-1') + + assert client.model == 'us.anthropic.claude-sonnet-4-20250514-v1:0' + assert client.temperature == 0.7 + assert client.region == 'ap-southeast-1' + + def test_init_default_region(self): + """Test initialization with default region.""" + client = AmazonBedrockLLMClient() + assert client.region == 'us-east-1' + + +class TestAmazonBedrockLLMClientGenerateResponse: + """Tests for AmazonBedrockLLMClient generate_response method.""" + + @pytest.mark.asyncio + async def test_generate_response_without_model(self, bedrock_client, mock_boto3_client): + """Test successful response generation without response model.""" + # Setup mock response + mock_response_body = {'content': [{'text': 'This is a test response'}]} + mock_response = {'body': MagicMock()} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + mock_boto3_client.invoke_model.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = mock_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await bedrock_client._generate_response(messages=messages) + + # Assertions + assert isinstance(result, dict) + assert result['content'] == 'This is a test response' + mock_executor.assert_called_once() + + @pytest.mark.asyncio + async def test_generate_response_with_model(self, bedrock_client, mock_boto3_client): + """Test successful response generation with response model.""" + # Setup mock response with valid JSON + mock_response_body = { + 'content': [{'text': '{"test_field": "test_value", "optional_field": 42}'}] + } + mock_response = {'body': MagicMock()} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + mock_boto3_client.invoke_model.return_value = mock_response + + # Call method + messages = [ + Message(role='system', content='System message'), + Message(role='user', content='User message'), + ] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = mock_response + mock_loop.return_value.run_in_executor = mock_executor + + result = await bedrock_client._generate_response( + messages=messages, response_model=ResponseModel + ) + + # Assertions + assert isinstance(result, dict) + assert result['test_field'] == 'test_value' + assert result['optional_field'] == 42 + + @pytest.mark.asyncio + async def test_generate_response_json_parsing_error(self, bedrock_client, mock_boto3_client): + """Test handling of JSON parsing errors.""" + # Setup mock response with invalid JSON + mock_response_body = {'content': [{'text': 'Invalid JSON response'}]} + mock_response = {'body': MagicMock()} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + mock_boto3_client.invoke_model.return_value = mock_response + + messages = [Message(role='user', content='Test message')] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = mock_response + mock_loop.return_value.run_in_executor = mock_executor + + with pytest.raises((ValueError, json.JSONDecodeError)): + await bedrock_client._generate_response( + messages=messages, response_model=ResponseModel + ) + + @pytest.mark.asyncio + async def test_rate_limit_error(self, bedrock_client, mock_boto3_client): + """Test handling of rate limit errors.""" + messages = [Message(role='user', content='Test message')] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.side_effect = Exception('Throttling error occurred') + mock_loop.return_value.run_in_executor = mock_executor + + with pytest.raises(RateLimitError): + await bedrock_client._generate_response(messages=messages) + + @pytest.mark.asyncio + async def test_clean_json_response(self, bedrock_client): + """Test the _clean_json_response method.""" + # Test with markdown code blocks + text_with_markdown = '```json\n{"test": "value"}\n```' + result = bedrock_client._clean_json_response(text_with_markdown) + assert result == '{"test": "value"}' + + # Test with double braces + text_with_double_braces = '{{"test": "value"}}' + result = bedrock_client._clean_json_response(text_with_double_braces) + assert result == '{"test": "value"}' + + # Test with extra text around JSON + text_with_extra = 'Here is the JSON: {"test": "value"} and some more text' + result = bedrock_client._clean_json_response(text_with_extra) + assert result == '{"test": "value"}' + + @pytest.mark.asyncio + async def test_invoke_bedrock_model_system_message(self, bedrock_client, mock_boto3_client): + """Test that system messages are handled correctly.""" + # Setup mock response + mock_response_body = {'content': [{'text': 'Response'}]} + mock_response = {'body': MagicMock()} + mock_response['body'].read.return_value.decode.return_value = json.dumps(mock_response_body) + mock_boto3_client.invoke_model.return_value = mock_response + + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant'}, + {'role': 'user', 'content': 'Hello'}, + ] + + with patch('asyncio.get_event_loop') as mock_loop: + mock_executor = AsyncMock() + mock_executor.return_value = mock_response + mock_loop.return_value.run_in_executor = mock_executor + + await bedrock_client._invoke_bedrock_model( + model='test-model', + messages=messages, + temperature=0.7, + max_tokens=100, + response_format='text', + ) + + # Verify the call was made with system prompt + mock_executor.assert_called_once() + # We can't easily inspect the lambda, but we know it was called + + +if __name__ == '__main__': + pytest.main(['-v', 'test_amazon_bedrock_client.py']) diff --git a/tests/llm_client/test_amazon_bedrock_client_int.py b/tests/llm_client/test_amazon_bedrock_client_int.py new file mode 100644 index 000000000..c88efedab --- /dev/null +++ b/tests/llm_client/test_amazon_bedrock_client_int.py @@ -0,0 +1,120 @@ +""" +Copyright 2024, Zep Software, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Running tests: pytest -xvs tests/llm_client/test_amazon_bedrock_client_int.py +# Requires: AWS credentials configured and Bedrock model access + +import os +from pathlib import Path + +import pytest +from pydantic import BaseModel + +from graphiti_core.llm_client.amazon_bedrock_client import AmazonBedrockLLMClient +from graphiti_core.llm_client.config import LLMConfig +from graphiti_core.prompts.models import Message + + +def _has_aws_credentials(): + """Check if AWS credentials are available.""" + # Check environment variables + if os.getenv('AWS_ACCESS_KEY_ID') or os.getenv('AWS_PROFILE'): + return True + + # Check for default credentials file + credentials_file = Path.home() / '.aws' / 'credentials' + return credentials_file.exists() + + +# Skip all tests if AWS credentials not available +pytestmark = pytest.mark.skipif(not _has_aws_credentials(), reason='AWS credentials not configured') + + +class ResponseModel(BaseModel): + """Test model for structured output.""" + + answer: str + confidence: float + + +@pytest.fixture +def bedrock_client(): + """Create a real AmazonBedrockLLMClient for integration testing.""" + config = LLMConfig( + model='us.anthropic.claude-sonnet-4-20250514-v1:0', temperature=0.1, max_tokens=100 + ) + return AmazonBedrockLLMClient(config=config, region='us-east-1') + + +class TestAmazonBedrockLLMClientIntegration: + """Integration tests for AmazonBedrockLLMClient with real AWS Bedrock.""" + + @pytest.mark.asyncio + async def test_simple_text_generation(self, bedrock_client): + """Test basic text generation without structured output.""" + messages = [ + Message(role='system', content='You are a helpful assistant.'), + Message(role='user', content='What is 2+2? Answer briefly.'), + ] + + result = await bedrock_client._generate_response(messages) + + assert isinstance(result, dict) + assert 'content' in result + assert isinstance(result['content'], str) + assert len(result['content']) > 0 + assert '4' in result['content'] + + @pytest.mark.asyncio + async def test_structured_output_generation(self, bedrock_client): + """Test structured output generation with Pydantic model.""" + messages = [ + Message(role='system', content='You are a helpful assistant.'), + Message( + role='user', content='What is the capital of France? Provide your confidence level.' + ), + ] + + result = await bedrock_client._generate_response(messages, response_model=ResponseModel) + + assert isinstance(result, dict) + assert 'answer' in result + assert 'confidence' in result + assert isinstance(result['answer'], str) + assert isinstance(result['confidence'], float) + assert 'Paris' in result['answer'] + assert 0.0 <= result['confidence'] <= 1.0 + + @pytest.mark.asyncio + async def test_different_regions(self): + """Test client works in different supported regions.""" + # Test with a different region + config = LLMConfig( + model='us.anthropic.claude-sonnet-4-20250514-v1:0', temperature=0.1, max_tokens=50 + ) + client = AmazonBedrockLLMClient(config=config, region='us-west-2') + + messages = [Message(role='user', content='Say hello in one word.')] + + result = await client._generate_response(messages) + + assert isinstance(result, dict) + assert 'content' in result + assert len(result['content']) > 0 + + +if __name__ == '__main__': + pytest.main(['-v', 'test_amazon_bedrock_client_int.py'])