2 changes: 1 addition & 1 deletion README.md
@@ -111,7 +111,7 @@ export PRIME_API_KEY=...


### Model Providers
We currently support most major clients (OpenAI, Anthropic), as well as the router platforms (OpenRouter, Portkey, LiteLLM). For local models, we recommend using vLLM (which interfaces with the [OpenAI client](https://github.com/alexzhang13/rlm/blob/main/rlm/clients/openai.py)). To view or add support for more clients, start by looking at [`rlm/clients/`](https://github.com/alexzhang13/rlm/tree/main/rlm/clients).
We currently support most major clients (OpenAI, Anthropic, Google Gemini), as well as the router platforms (OpenRouter, Portkey, LiteLLM). For local models, we recommend using vLLM (which interfaces with the [OpenAI client](https://github.com/alexzhang13/rlm/blob/main/rlm/clients/openai.py)). To view or add support for more clients, start by looking at [`rlm/clients/`](https://github.com/alexzhang13/rlm/tree/main/rlm/clients).
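For example, with `GEMINI_API_KEY` set, the Gemini provider added in this PR can be exercised directly through its client class; a minimal sketch, assuming `GeminiClient` is importable from `rlm/clients/gemini.py`:

```python
# Minimal sketch (import path assumed from the rlm/clients/ layout).
from rlm.clients.gemini import GeminiClient

# Falls back to the GEMINI_API_KEY environment variable when api_key is omitted.
client = GeminiClient(model_name="gemini-2.5-flash")
print(client.completion("Summarize recursive language models in one sentence."))
```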

## Relevant Reading
* **[Dec '25]** [Recursive Language Models arXiv](https://arxiv.org/abs/2512.24601)
169 changes: 155 additions & 14 deletions rlm/clients/gemini.py
@@ -1,5 +1,6 @@
import os
from collections import defaultdict
from collections.abc import AsyncIterator, Iterator
from typing import Any

from dotenv import load_dotenv
@@ -12,6 +13,10 @@
load_dotenv()

DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
DEFAULT_VERTEXAI = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() in ("true", "1")  # env flag parsed as a bool
DEFAULT_PROJECT = os.getenv("GOOGLE_CLOUD_PROJECT")
DEFAULT_LOCATION = os.getenv("GOOGLE_CLOUD_LOCATION")
DEFAULT_MODEL_NAME = "gemini-2.5-flash"


class GeminiClient(BaseLM):
@@ -23,22 +28,59 @@ class GeminiClient(BaseLM):
def __init__(
self,
api_key: str | None = None,
model_name: str | None = "gemini-2.5-flash",
model_name: str | None = DEFAULT_MODEL_NAME,
vertexai: bool = False,
project: str | None = None,
location: str | None = None,
**kwargs,
):
super().__init__(model_name=model_name, **kwargs)
"""
Initialize the Gemini Client.

Args:
model_name: The ID of the model to use.
api_key: API key for Gemini API.
vertexai: If True, use Vertex AI.
project: Google Cloud project ID (required if vertexai=True).
location: Google Cloud location (required if vertexai=True).
**kwargs: Additional arguments passed to the genai.Client.
Supported kwargs:
- safety_settings: Optional safety settings configuration for content filtering.
"""

if api_key is None:
api_key = DEFAULT_GEMINI_API_KEY
super().__init__(model_name=model_name, **kwargs)

api_key = api_key or DEFAULT_GEMINI_API_KEY
vertexai = vertexai or DEFAULT_VERTEXAI
project = project or DEFAULT_PROJECT
location = location or DEFAULT_LOCATION

# Optional safety settings configuration
self.safety_settings = kwargs.pop("safety_settings", None)

# Try Gemini API first (unless vertexai is explicitly True)
if not vertexai and api_key:
self.client = genai.Client(api_key=api_key, **kwargs)
# If vertexai=True or we don't have a Gemini API key, try Vertex AI
elif vertexai or (not api_key and (project or location)):
if not project or not location:
raise ValueError(
"Vertex AI requires a project ID and location. "
"Set it via `project` and `location` arguments or `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION` environment variables."
)
self.client = genai.Client(
vertexai=True,
project=project,
location=location,
**kwargs,
)
# No valid configuration found
else:
raise ValueError(
"Gemini API key is required. Set GEMINI_API_KEY env var or pass api_key."
" For Vertex AI, ensure GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION are set."
)

self.client = genai.Client(api_key=api_key)
self.model_name = model_name

# Per-model usage tracking
self.model_call_counts: dict[str, int] = defaultdict(int)
self.model_input_tokens: dict[str, int] = defaultdict(int)
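A hedged usage sketch of the two configuration paths handled above, one with a Gemini API key plus optional `safety_settings`, one with Vertex AI; the `types.SafetySetting` enum values are illustrative and the import path is assumed:

```python
# Sketch only: both auth paths from __init__; enum values are illustrative.
from google.genai import types

from rlm.clients.gemini import GeminiClient  # assumed import path

# Path 1: Gemini API key (or GEMINI_API_KEY from the environment),
# with safety settings forwarded via the safety_settings kwarg.
api_client = GeminiClient(
    api_key="...",
    model_name="gemini-2.5-flash",
    safety_settings=[
        types.SafetySetting(
            category=types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            threshold=types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
        )
    ],
)

# Path 2: Vertex AI, which requires a project and a location
# (or GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION in the environment).
vertex_client = GeminiClient(
    vertexai=True,
    project="my-gcp-project",  # hypothetical project ID
    location="us-central1",
    model_name="gemini-2.5-flash",
)
```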
@@ -56,9 +98,13 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = Non
if not model:
raise ValueError("Model name is required for Gemini client.")

config = None
config = types.GenerateContentConfig()
if system_instruction:
config = types.GenerateContentConfig(system_instruction=system_instruction)
config.system_instruction = system_instruction

# Apply safety settings if configured
if self.safety_settings:
config.safety_settings = self.safety_settings

response = self.client.models.generate_content(
model=model,
@@ -69,6 +115,50 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = Non
self._track_cost(response, model)
return response.text

def completion_stream(
self,
prompt: str | list[dict[str, Any]],
model: str | None = None,
) -> Iterator[str]:
"""
Stream completion tokens from the Gemini model.

Args:
prompt: The prompt (string or message list)
model: Optional model override

Yields:
Token chunks as they arrive
"""

model = model or self.model_name
if not model:
raise ValueError("Model name is required for Gemini client.")

contents, system_instruction = self._prepare_contents(prompt)

config = types.GenerateContentConfig()
if system_instruction:
config.system_instruction = system_instruction

if self.safety_settings:
config.safety_settings = self.safety_settings

response_stream = self.client.models.generate_content_stream(
model=model, contents=contents, config=config
)

# Track the last chunk for usage metadata
last_chunk = None
for chunk in response_stream:
last_chunk = chunk
if chunk.text:
yield chunk.text

# Track usage from the final chunk if available
if last_chunk:
self._track_cost(last_chunk, model)
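A minimal consumption sketch for the synchronous stream (import path assumed):

```python
from rlm.clients.gemini import GeminiClient  # assumed import path

client = GeminiClient(model_name="gemini-2.5-flash")  # uses GEMINI_API_KEY
for chunk in client.completion_stream("Explain context rot in two sentences."):
    print(chunk, end="", flush=True)
print()
```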

async def acompletion(
self, prompt: str | list[dict[str, Any]], model: str | None = None
) -> str:
@@ -78,9 +168,12 @@ async def acompletion
if not model:
raise ValueError("Model name is required for Gemini client.")

config = None
config = types.GenerateContentConfig()
if system_instruction:
config = types.GenerateContentConfig(system_instruction=system_instruction)
config.system_instruction = system_instruction

if self.safety_settings:
config.safety_settings = self.safety_settings

# google-genai SDK supports async via aio interface
response = await self.client.aio.models.generate_content(
@@ -92,6 +185,50 @@
self._track_cost(response, model)
return response.text

async def acompletion_stream(
self,
prompt: str | list[dict[str, Any]],
model: str | None = None,
) -> AsyncIterator[str]:
"""
Async stream completion tokens from the Gemini model.

Args:
prompt: The prompt (string or message list)
model: Optional model override

Yields:
Token chunks as they arrive
"""

model = model or self.model_name
if not model:
raise ValueError("Model name is required for Gemini client.")

contents, system_instruction = self._prepare_contents(prompt)

config = types.GenerateContentConfig()
if system_instruction:
config.system_instruction = system_instruction

if self.safety_settings:
config.safety_settings = self.safety_settings

response_stream = await self.client.aio.models.generate_content_stream(
model=model, contents=contents, config=config
)

# Track the last chunk for usage metadata
last_chunk = None
async for chunk in response_stream:
last_chunk = chunk
if chunk.text:
yield chunk.text

# Track usage from the final chunk if available
if last_chunk:
self._track_cost(last_chunk, model)
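And the async counterpart, driven with asyncio (import path assumed):

```python
import asyncio

from rlm.clients.gemini import GeminiClient  # assumed import path


async def main() -> None:
    client = GeminiClient(model_name="gemini-2.5-flash")  # uses GEMINI_API_KEY
    async for chunk in client.acompletion_stream("Name three uses of recursion."):
        print(chunk, end="", flush=True)
    print()


asyncio.run(main())
```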

def _prepare_contents(
self, prompt: str | list[dict[str, Any]]
) -> tuple[list[types.Content] | str, str | None]:
@@ -110,7 +247,10 @@ def _prepare_contents(

if role == "system":
# Gemini handles system instruction separately
system_instruction = content
if system_instruction:
system_instruction += "\n" + content
else:
system_instruction = content
elif role == "user":
contents.append(types.Content(role="user", parts=[types.Part(text=content)]))
elif role == "assistant":
@@ -132,10 +272,11 @@ def _track_cost(self, response: types.GenerateContentResponse, model: str):
if usage:
input_tokens = usage.prompt_token_count or 0
output_tokens = usage.candidates_token_count or 0
total_tokens = usage.total_token_count or (input_tokens + output_tokens)

self.model_input_tokens[model] += input_tokens
self.model_output_tokens[model] += output_tokens
self.model_total_tokens[model] += input_tokens + output_tokens
self.model_total_tokens[model] += total_tokens

# Track last call for handler to read
self.last_prompt_tokens = input_tokens
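Together, the `_prepare_contents` and `_track_cost` changes mean that multiple system messages are now concatenated into a single system instruction and that totals prefer the SDK-reported `total_token_count`; a small sketch of how a caller might observe both (attribute names from the diff, import path assumed):

```python
from rlm.clients.gemini import GeminiClient  # assumed import path

client = GeminiClient(model_name="gemini-2.5-flash")
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "system", "content": "Answer in English."},  # merged with the first system message
    {"role": "user", "content": "What is a recursive language model?"},
]
print(client.completion(messages))

# Per-model usage accumulated by _track_cost after each call.
print(client.model_total_tokens["gemini-2.5-flash"])
```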