diff --git a/src/any_llm/providers/bedrock/bedrock.py b/src/any_llm/providers/bedrock/bedrock.py index cce7594f..8feeefac 100644 --- a/src/any_llm/providers/bedrock/bedrock.py +++ b/src/any_llm/providers/bedrock/bedrock.py @@ -37,7 +37,7 @@ class BedrockProvider(AnyLLM): SUPPORTS_COMPLETION = True SUPPORTS_RESPONSES = False SUPPORTS_COMPLETION_REASONING = True - SUPPORTS_COMPLETION_IMAGE = False + SUPPORTS_COMPLETION_IMAGE = True SUPPORTS_COMPLETION_PDF = False SUPPORTS_EMBEDDING = True SUPPORTS_LIST_MODELS = True diff --git a/src/any_llm/providers/bedrock/utils.py b/src/any_llm/providers/bedrock/utils.py index 27a0650f..4b16f64d 100644 --- a/src/any_llm/providers/bedrock/utils.py +++ b/src/any_llm/providers/bedrock/utils.py @@ -114,15 +114,82 @@ def _convert_messages(messages: list[dict[str, Any]]) -> tuple[list[dict[str, An if bedrock_message: formatted_messages.append(bedrock_message) else: # user messages + if isinstance(message.get("content"), list): + # Handle messages with structured content (e.g., images) + content = _convert_images(message["content"]) + else: + # Handle simple text messages + content = [{"text": message["content"]}] formatted_messages.append( { "role": message["role"], - "content": [{"text": message["content"]}], + "content": content, } ) return system_message, formatted_messages +def _convert_images(content: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert images from OpenAI format to AWS Bedrock format. + + - Parse the "content" field block by block + - Convert image blocks to Bedrock format + - Support both URL and base64 encoded images + """ + converted_content = [] + for block in content: + if block.get("type") == "image_url": + converted_block: dict[str, Any] = {"image": {}} + url = block.get("image_url", {}).get("url", "") + + if url.startswith("data:image/"): + # Handle base64 encoded images + mime_part = url[5:] # Remove "data:" + semi_idx = mime_part.find(";") + if semi_idx != -1: + media_type = mime_part[:semi_idx] + # Extract format from media type (e.g., "image/png" -> "png") + format_type = media_type.split("/")[1] if "/" in media_type else "png" + base64_data = url.split("base64,", 1)[1] if "base64," in url else "" + converted_block["image"] = { + "format": format_type, + "source": { + "type": "base64", + "data": base64_data, + } + } + else: + # Fallback if parsing fails + converted_block["image"] = { + "format": "png", + "source": { + "type": "base64", + "data": url.split("base64,", 1)[1] if "base64," in url else "", + } + } + else: + # Handle URL images - need to determine format from URL or default to png + format_type = "png" # Default format + if url.lower().endswith((".jpg", ".jpeg")): + format_type = "jpeg" + elif url.lower().endswith(".png"): + format_type = "png" + elif url.lower().endswith(".gif"): + format_type = "gif" + elif url.lower().endswith(".webp"): + format_type = "webp" + + converted_block["image"] = { + "format": format_type, + "source": { + "type": "url", + "url": url, + } + } + converted_content.append(converted_block) + else: + converted_content.append(block) + return converted_content def _convert_tool_result(message: dict[str, Any]) -> dict[str, Any] | None: """Convert OpenAI tool result format to AWS Bedrock format.""" diff --git a/tests/unit/providers/test_aws_provider.py b/tests/unit/providers/test_aws_provider.py index a14e26b0..8fb05a0a 100644 --- a/tests/unit/providers/test_aws_provider.py +++ b/tests/unit/providers/test_aws_provider.py @@ -150,3 +150,36 @@ def test_embedding_list_of_strings() -> None: assert response.data[1].index == 1 assert response.usage.prompt_tokens == 11 assert response.usage.total_tokens == 11 + + +def test_completion_with_images() -> None: + """Test that completion correctly processes messages with images.""" + model_id = "model-id" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Some question about these images."}, + {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,qwertyuiopasdfghjklzxcvbnm"}}, + ], + } + ] + + with mock_aws_provider() as mock_boto3_client: + provider = BedrockProvider(api_key="test_key") + provider._completion(CompletionParams(model_id=model_id, messages=messages)) + + mock_boto3_client.return_value.converse.assert_called_once_with( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "Some question about these images."}, + {"image": {"format": "png", "source": {"type": "url", "url": "https://example.com/a.png"}}}, + {"image": {"format": "jpeg", "source": {"type": "base64", "data": "qwertyuiopasdfghjklzxcvbnm"}}}, + ], + } + ], + modelId=model_id, + )