Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 194 additions & 8 deletions rlm/clients/gemini.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import base64
import os
from collections import defaultdict
from pathlib import Path
from typing import Any

from dotenv import load_dotenv
Expand All @@ -11,6 +13,119 @@

load_dotenv()


def _load_image_as_part(image_source: str | dict) -> types.Part:
"""Load an image and return a Gemini Part object.

Args:
image_source: Either a file path (str), URL (str starting with http),
or a dict with 'type' and 'data' keys for base64 images.

Returns:
A Gemini Part object containing the image.
"""
if isinstance(image_source, dict):
# Base64 encoded image: {"type": "base64", "media_type": "image/png", "data": "..."}
if image_source.get("type") == "base64":
image_bytes = base64.b64decode(image_source["data"])
mime_type = image_source.get("media_type", "image/png")
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
# URL format from OpenAI-style: {"type": "image_url", "image_url": {"url": "..."}}
elif image_source.get("type") == "image_url":
url = image_source["image_url"]["url"]
if url.startswith("data:"):
# Data URL: data:image/png;base64,...
header, data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
image_bytes = base64.b64decode(data)
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
else:
return types.Part.from_uri(file_uri=url, mime_type="image/jpeg")
elif isinstance(image_source, str):
if image_source.startswith(("http://", "https://")):
# URL
return types.Part.from_uri(file_uri=image_source, mime_type="image/jpeg")
else:
# Local file path
path = Path(image_source)
if path.exists():
mime_type = _get_mime_type(path)
with open(path, "rb") as f:
return types.Part.from_bytes(data=f.read(), mime_type=mime_type)
else:
raise FileNotFoundError(f"Image file not found: {image_source}")
raise ValueError(f"Unsupported image source type: {type(image_source)}")


def _get_mime_type(path: Path) -> str:
"""Get MIME type from file extension."""
suffix = path.suffix.lower()
mime_types = {
# Images
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
# Audio
".mp3": "audio/mpeg",
".wav": "audio/wav",
".ogg": "audio/ogg",
".flac": "audio/flac",
".m4a": "audio/mp4",
".aac": "audio/aac",
".webm": "audio/webm",
# Video
".mp4": "video/mp4",
".mpeg": "video/mpeg",
".mov": "video/quicktime",
".avi": "video/x-msvideo",
".mkv": "video/x-matroska",
}
return mime_types.get(suffix, "application/octet-stream")


def _load_audio_as_part(audio_source: str | dict) -> types.Part:
"""Load an audio file and return a Gemini Part object.

Args:
audio_source: Either a file path (str), URL (str starting with http),
or a dict with 'type' and 'data' keys for base64 audio.

Returns:
A Gemini Part object containing the audio.
"""
if isinstance(audio_source, dict):
# Base64 encoded audio
if audio_source.get("type") == "base64":
audio_bytes = base64.b64decode(audio_source["data"])
mime_type = audio_source.get("media_type", "audio/mpeg")
return types.Part.from_bytes(data=audio_bytes, mime_type=mime_type)
# Path format
elif audio_source.get("type") == "audio_path":
path = Path(audio_source.get("path", ""))
if path.exists():
mime_type = _get_mime_type(path)
with open(path, "rb") as f:
return types.Part.from_bytes(data=f.read(), mime_type=mime_type)
else:
raise FileNotFoundError(f"Audio file not found: {audio_source.get('path')}")
elif isinstance(audio_source, str):
if audio_source.startswith(("http://", "https://")):
# URL - let Gemini fetch it
return types.Part.from_uri(file_uri=audio_source, mime_type="audio/mpeg")
else:
# Local file path
path = Path(audio_source)
if path.exists():
mime_type = _get_mime_type(path)
with open(path, "rb") as f:
return types.Part.from_bytes(data=f.read(), mime_type=mime_type)
else:
raise FileNotFoundError(f"Audio file not found: {audio_source}")
raise ValueError(f"Unsupported audio source type: {type(audio_source)}")

# Default API key, read from the GEMINI_API_KEY environment variable when this
# module is imported; None if the variable is not set.
DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")


Expand Down Expand Up @@ -95,7 +210,18 @@ async def acompletion(
def _prepare_contents(
self, prompt: str | list[dict[str, Any]]
) -> tuple[list[types.Content] | str, str | None]:
"""Prepare contents and extract system instruction for Gemini API."""
"""Prepare contents and extract system instruction for Gemini API.

Supports multimodal content where message content can be:
- A string (text only)
- A list of content items (text and images mixed)

Image items can be:
- {"type": "text", "text": "..."}
- {"type": "image_url", "image_url": {"url": "..."}}
- {"type": "image_path", "path": "/path/to/image.png"}
- {"type": "base64", "media_type": "image/png", "data": "..."}
"""
system_instruction = None

if isinstance(prompt, str):
Expand All @@ -110,20 +236,80 @@ def _prepare_contents(

if role == "system":
# Gemini handles system instruction separately
system_instruction = content
elif role == "user":
contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
elif role == "assistant":
# Gemini uses "model" instead of "assistant"
contents.append(types.Content(role="model", parts=[types.Part(text=content)]))
if isinstance(content, str):
system_instruction = content
elif isinstance(content, list):
# Extract text from system message list
system_parts = []
for item in content:
if isinstance(item, str):
system_parts.append(item)
elif isinstance(item, dict) and item.get("type") == "text":
system_parts.append(item.get("text", ""))
system_instruction = "\n".join(system_parts)
elif role in ("user", "assistant"):
gemini_role = "user" if role == "user" else "model"
parts = self._content_to_parts(content)
if parts:
contents.append(types.Content(role=gemini_role, parts=parts))
else:
# Default to user role for unknown roles
contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
parts = self._content_to_parts(content)
if parts:
contents.append(types.Content(role="user", parts=parts))

return contents, system_instruction

raise ValueError(f"Invalid prompt type: {type(prompt)}")

def _content_to_parts(self, content: str | list) -> list[types.Part]:
    """Translate one message's content into Gemini ``Part`` objects.

    Args:
        content: A plain string, or a list whose items are strings or dicts
            tagged with a ``"type"`` key (``text``, ``image_url``,
            ``image_path``, ``base64``, ``audio_path``, ``audio_url``).

    Returns:
        The parts in input order. Media that fails to load is replaced by a
        text part describing the error, and list items of any other kind are
        skipped. Content that is neither ``str`` nor ``list`` is stringified
        into a single text part.
    """
    if isinstance(content, str):
        return [types.Part(text=content)]
    if not isinstance(content, list):
        return [types.Part(text=str(content))]

    converted = []
    for entry in content:
        if isinstance(entry, str):
            converted.append(types.Part(text=entry))
            continue
        if not isinstance(entry, dict):
            continue

        kind = entry.get("type", "text")
        if kind == "text":
            converted.append(types.Part(text=entry.get("text", "")))
        elif kind in ("image_url", "image_path", "base64"):
            try:
                # Local files are passed by path; other forms pass the dict.
                source = entry.get("path", "") if kind == "image_path" else entry
                converted.append(_load_image_as_part(source))
            except Exception as err:
                # Best effort: report the failure in-band as text.
                converted.append(types.Part(text=f"[Image load error: {err}]"))
        elif kind in ("audio_path", "audio_url"):
            field = "path" if kind == "audio_path" else "url"
            try:
                converted.append(_load_audio_as_part(entry.get(field, "")))
            except Exception as err:
                converted.append(types.Part(text=f"[Audio load error: {err}]"))
    return converted

def _track_cost(self, response: types.GenerateContentResponse, model: str):
self.model_call_counts[model] += 1

Expand Down
Loading
Loading