From 833d902e7cb13c568bce46a3a0c3f5161c838c67 Mon Sep 17 00:00:00 2001 From: Karl Weinmeister Date: Mon, 5 Jan 2026 19:40:51 -0600 Subject: [PATCH 1/2] feat: Add streaming, Vertex AI support, and safety settings to the Gemini client --- README.md | 2 +- rlm/clients/gemini.py | 169 +++++++++++++++++++++++++++++++--- tests/clients/test_gemini.py | 174 ++++++++++++++++++++++++++++++++++- 3 files changed, 328 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 48b5c988..d58137a5 100644 --- a/README.md +++ b/README.md @@ -111,7 +111,7 @@ export PRIME_API_KEY=... ### Model Providers -We currently support most major clients (OpenAI, Anthropic), as well as the router platforms (OpenRouter, Portkey, LiteLLM). For local models, we recommend using vLLM (which interfaces with the [OpenAI client](https://github.com/alexzhang13/rlm/blob/main/rlm/clients/openai.py)). To view or add support for more clients, start by looking at [`rlm/clients/`](https://github.com/alexzhang13/rlm/tree/main/rlm/clients). +We currently support most major clients (OpenAI, Anthropic, Google Gemini), as well as the router platforms (OpenRouter, Portkey, LiteLLM). For local models, we recommend using vLLM (which interfaces with the [OpenAI client](https://github.com/alexzhang13/rlm/blob/main/rlm/clients/openai.py)). To view or add support for more clients, start by looking at [`rlm/clients/`](https://github.com/alexzhang13/rlm/tree/main/rlm/clients). ## Relevant Reading * **[Dec '25]** [Recursive Language Models arXiv](https://arxiv.org/abs/2512.24601) diff --git a/rlm/clients/gemini.py b/rlm/clients/gemini.py index 7f6dc152..ba7f7a2a 100644 --- a/rlm/clients/gemini.py +++ b/rlm/clients/gemini.py @@ -1,5 +1,6 @@ import os from collections import defaultdict +from collections.abc import AsyncIterator, Iterator from typing import Any from dotenv import load_dotenv @@ -12,6 +13,10 @@ load_dotenv() DEFAULT_GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") +DEFAULT_VERTEXAI = os.getenv("GOOGLE_GENAI_USE_VERTEXAI") +DEFAULT_PROJECT = os.getenv("GOOGLE_CLOUD_PROJECT") +DEFAULT_LOCATION = os.getenv("GOOGLE_CLOUD_LOCATION") +DEFAULT_MODEL_NAME = "gemini-2.5-flash" class GeminiClient(BaseLM): @@ -23,22 +28,59 @@ class GeminiClient(BaseLM): def __init__( self, api_key: str | None = None, - model_name: str | None = "gemini-2.5-flash", + model_name: str | None = DEFAULT_MODEL_NAME, + vertexai: bool = False, + project: str | None = None, + location: str | None = None, **kwargs, ): - super().__init__(model_name=model_name, **kwargs) + """ + Initialize the Gemini Client. + + Args: + model_name: The ID of the model to use. + api_key: API key for Gemini API. + vertexai: If True, use Vertex AI. + project: Google Cloud project ID (required if vertexai=True). + location: Google Cloud location (required if vertexai=True). + **kwargs: Additional arguments passed to the genai.Client. + Supported kwargs: + - safety_settings: Optional safety settings configuration for content filtering. 
+ """ - if api_key is None: - api_key = DEFAULT_GEMINI_API_KEY + super().__init__(model_name=model_name, **kwargs) - if api_key is None: + api_key = api_key or DEFAULT_GEMINI_API_KEY + vertexai = vertexai or DEFAULT_VERTEXAI + project = project or DEFAULT_PROJECT + location = location or DEFAULT_LOCATION + + # Optional safety settings configuration + self.safety_settings = kwargs.pop("safety_settings", None) + + # Try Gemini API first (unless vertexai is explicitly True) + if not vertexai and api_key: + self.client = genai.Client(api_key=api_key, **kwargs) + # If vertexai=True or we don't have a Gemini API key, try Vertex AI + elif vertexai or (not api_key and (project or location)): + if not project or not location: + raise ValueError( + "Vertex AI requires a project ID and location. " + "Set it via `project` and `location` arguments or `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION` environment variables." + ) + self.client = genai.Client( + vertexai=True, + project=project, + location=location, + **kwargs, + ) + # No valid configuration found + else: raise ValueError( "Gemini API key is required. Set GEMINI_API_KEY env var or pass api_key." + " For Vertex AI, ensure GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION are set." ) - self.client = genai.Client(api_key=api_key) - self.model_name = model_name - # Per-model usage tracking self.model_call_counts: dict[str, int] = defaultdict(int) self.model_input_tokens: dict[str, int] = defaultdict(int) @@ -56,9 +98,13 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = Non if not model: raise ValueError("Model name is required for Gemini client.") - config = None + config = types.GenerateContentConfig() if system_instruction: - config = types.GenerateContentConfig(system_instruction=system_instruction) + config.system_instruction = system_instruction + + # Apply safety settings if configured + if self.safety_settings: + config.safety_settings = self.safety_settings response = self.client.models.generate_content( model=model, @@ -69,6 +115,50 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = Non self._track_cost(response, model) return response.text + def completion_stream( + self, + prompt: str | list[dict[str, Any]], + model: str | None = None, + ) -> Iterator[str]: + """ + Stream completion tokens from the Gemini model. 
+ + Args: + prompt: The prompt (string or message list) + model: Optional model override + + Yields: + Token chunks as they arrive + """ + + model = model or self.model_name + if not model: + raise ValueError("Model name is required for Gemini client.") + + contents, system_instruction = self._prepare_contents(prompt) + + config = types.GenerateContentConfig() + if system_instruction: + config.system_instruction = system_instruction + + if self.safety_settings: + config.safety_settings = self.safety_settings + + response_stream = self.client.models.generate_content_stream( + model=model, contents=contents, config=config + ) + + # Track the last chunk for usage metadata + last_chunk = None + for chunk in response_stream: + last_chunk = chunk + if chunk.text: + yield chunk.text + + # Track usage from the final chunk if available + if last_chunk: + self._track_cost(last_chunk, model) + async def acompletion( self, prompt: str | list[dict[str, Any]], model: str | None = None ) -> str: @@ -78,9 +168,12 @@ async def acompletion( if not model: raise ValueError("Model name is required for Gemini client.") - config = None + config = types.GenerateContentConfig() if system_instruction: - config = types.GenerateContentConfig(system_instruction=system_instruction) + config.system_instruction = system_instruction + + if self.safety_settings: + config.safety_settings = self.safety_settings # google-genai SDK supports async via aio interface response = await self.client.aio.models.generate_content( @@ -92,6 +185,50 @@ async def acompletion( self._track_cost(response, model) return response.text + async def acompletion_stream( + self, + prompt: str | list[dict[str, Any]], + model: str | None = None, + ) -> AsyncIterator[str]: + """ + Async stream completion tokens from the Gemini model. 
+ + Args: + prompt: The prompt (string or message list) + model: Optional model override + + Yields: + Token chunks as they arrive + """ + + model = model or self.model_name + if not model: + raise ValueError("Model name is required for Gemini client.") + + contents, system_instruction = self._prepare_contents(prompt) + + config = types.GenerateContentConfig() + if system_instruction: + config.system_instruction = system_instruction + + if self.safety_settings: + config.safety_settings = self.safety_settings + + response_stream = await self.client.aio.models.generate_content_stream( + model=model, contents=contents, config=config + ) + + # Track the last chunk for usage metadata + last_chunk = None + async for chunk in response_stream: + last_chunk = chunk + if chunk.text: + yield chunk.text + + # Track usage from the final chunk if available + if last_chunk: + self._track_cost(last_chunk, model) + def _prepare_contents( self, prompt: str | list[dict[str, Any]] ) -> tuple[list[types.Content] | str, str | None]: @@ -110,7 +247,10 @@ def _prepare_contents( if role == "system": # Gemini handles system instruction separately - system_instruction = content + if system_instruction: + system_instruction += "\n" + content + else: + system_instruction = content elif role == "user": contents.append(types.Content(role="user", parts=[types.Part(text=content)])) elif role == "assistant": @@ -132,10 +272,11 @@ def _track_cost(self, response: types.GenerateContentResponse, model: str): if usage: input_tokens = usage.prompt_token_count or 0 output_tokens = usage.candidates_token_count or 0 + total_tokens = usage.total_token_count or (input_tokens + output_tokens) self.model_input_tokens[model] += input_tokens self.model_output_tokens[model] += output_tokens - self.model_total_tokens[model] += input_tokens + output_tokens + self.model_total_tokens[model] += total_tokens # Track last call for handler to read self.last_prompt_tokens = input_tokens diff --git a/tests/clients/test_gemini.py b/tests/clients/test_gemini.py index 181a3ee8..91944084 100644 --- a/tests/clients/test_gemini.py +++ b/tests/clients/test_gemini.py @@ -1,5 +1,6 @@ """Tests for the Gemini client.""" +import asyncio import os from unittest.mock import MagicMock, patch @@ -30,7 +31,11 @@ def test_init_default_model(self): def test_init_requires_api_key(self): """Test client raises error when no API key provided.""" with patch.dict(os.environ, {}, clear=True): - with patch("rlm.clients.gemini.DEFAULT_GEMINI_API_KEY", None): + with ( + patch("rlm.clients.gemini.DEFAULT_GEMINI_API_KEY", None), + patch("rlm.clients.gemini.DEFAULT_PROJECT", None), + patch("rlm.clients.gemini.DEFAULT_LOCATION", None), + ): with pytest.raises(ValueError, match="Gemini API key is required"): GeminiClient(api_key=None) @@ -115,7 +120,7 @@ def test_completion_requires_model(self): with pytest.raises(ValueError, match="Model name is required"): client.completion("Hello") - def test_completion_with_mocked_response(self): + def test_completion_mocked(self): """Test completion with mocked API response.""" mock_response = MagicMock() mock_response.text = "Hello from Gemini!" @@ -133,8 +138,127 @@ def test_completion_with_mocked_response(self): assert result == "Hello from Gemini!" 
         assert client.model_call_counts["gemini-2.5-flash"] == 1
         assert client.model_input_tokens["gemini-2.5-flash"] == 10
+        assert client.model_input_tokens["gemini-2.5-flash"] == 10
+        assert client.model_output_tokens["gemini-2.5-flash"] == 5
+
+    @pytest.mark.asyncio
+    async def test_acompletion_mocked(self):
+        """Test async completion with mocked API response."""
+        mock_response = MagicMock()
+        mock_response.text = "Hello from async Gemini!"
+        mock_response.usage_metadata.prompt_token_count = 10
+        mock_response.usage_metadata.candidates_token_count = 5
+
+        async def mock_call(*args, **kwargs):
+            return mock_response
+
+        with patch("rlm.clients.gemini.genai.Client") as mock_client_class:
+            mock_client = MagicMock()
+            mock_client.aio.models.generate_content.side_effect = mock_call
+            mock_client_class.return_value = mock_client
+
+            client = GeminiClient(api_key="test-key", model_name="gemini-2.5-flash")
+            result = await client.acompletion("Hello")
+
+            assert result == "Hello from async Gemini!"
+            assert client.model_call_counts["gemini-2.5-flash"] == 1
+            assert client.model_input_tokens["gemini-2.5-flash"] == 10
+            assert client.model_output_tokens["gemini-2.5-flash"] == 5
+
+    @pytest.mark.asyncio
+    async def test_acompletion_stream_mocked(self):
+        """Test acompletion_stream with mocked API response."""
+        mock_chunk = MagicMock()
+        mock_chunk.text = "Hello"
+        mock_chunk.usage_metadata.prompt_token_count = 10
+        mock_chunk.usage_metadata.candidates_token_count = 5
+
+        async def mock_stream():
+            yield mock_chunk
+
+        async def mock_call(*args, **kwargs):
+            return mock_stream()
+
+        with patch("rlm.clients.gemini.genai.Client") as mock_client_class:
+            mock_client = MagicMock()
+            # The async client method is a coroutine that returns an async iterator
+            mock_client.aio.models.generate_content_stream.side_effect = mock_call
+            mock_client_class.return_value = mock_client
+
+            client = GeminiClient(api_key="test-key", model_name="gemini-2.5-flash")
+            chunks = []
+            async for chunk in client.acompletion_stream("Hello"):
+                chunks.append(chunk)
+
+            assert chunks == ["Hello"]
+            assert client.model_call_counts["gemini-2.5-flash"] == 1
+            assert client.model_input_tokens["gemini-2.5-flash"] == 10
+            assert client.model_output_tokens["gemini-2.5-flash"] == 5
+
+    def test_completion_stream_mocked(self):
+        """Test completion_stream with mocked API response."""
+        mock_chunk = MagicMock()
+        mock_chunk.text = "Hello"
+        mock_chunk.usage_metadata.prompt_token_count = 10
+        mock_chunk.usage_metadata.candidates_token_count = 5
+
+        def mock_stream():
+            yield mock_chunk
+
+        with patch("rlm.clients.gemini.genai.Client") as mock_client_class:
+            mock_client = MagicMock()
+            mock_client.models.generate_content_stream.return_value = mock_stream()
+            mock_client_class.return_value = mock_client
+
+            client = GeminiClient(api_key="test-key", model_name="gemini-2.5-flash")
+            chunks = list(client.completion_stream("Hello"))
+
+            assert chunks == ["Hello"]
+            assert client.model_call_counts["gemini-2.5-flash"] == 1
+            assert client.model_input_tokens["gemini-2.5-flash"] == 10
         assert client.model_output_tokens["gemini-2.5-flash"] == 5
 
+    def test_init_vertexai_mocked(self):
+        """Test client initialization with Vertex AI."""
+        with patch("rlm.clients.gemini.genai.Client") as mock_client_class:
+            client = GeminiClient(
+                vertexai=True,
+                project="test-project",
+                location="us-central1",
+                model_name="gemini-2.5-flash",
+            )
+            mock_client_class.assert_called_once()
+            args, kwargs = mock_client_class.call_args
+            assert kwargs["vertexai"] is True
+            assert kwargs["project"] == "test-project"
+            assert kwargs["location"] == "us-central1"
+            assert client.model_name == "gemini-2.5-flash"
+
+    def test_safety_settings_mocked(self):
+        """Test that safety settings are correctly passed to the config."""
+        safety_settings = [
+            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_LOW_AND_ABOVE"}
+        ]
+        mock_response = MagicMock()
+        mock_response.text = "Safe response"
+        mock_response.usage_metadata.prompt_token_count = 5
+        mock_response.usage_metadata.candidates_token_count = 5
+
+        with patch("rlm.clients.gemini.genai.Client") as mock_client_class:
+            mock_client = MagicMock()
+            mock_client.models.generate_content.return_value = mock_response
+            mock_client_class.return_value = mock_client
+
+            client = GeminiClient(api_key="test-key", safety_settings=safety_settings)
+            assert client.safety_settings == safety_settings
+
+            client.completion("Hello")
+
+            # Verify GenerateContentConfig was called with safety_settings
+            call_args = mock_client.models.generate_content.call_args
+            config = call_args.kwargs["config"]
+            assert config.safety_settings == safety_settings
+
 
 class TestGeminiClientIntegration:
     """Integration tests that require a real API key."""
@@ -172,12 +296,52 @@ def test_message_list_completion(self):
         not os.environ.get("GEMINI_API_KEY"),
         reason="GEMINI_API_KEY not set",
     )
+    @pytest.mark.asyncio
     async def test_async_completion(self):
         """Test async completion."""
         client = GeminiClient(model_name="gemini-2.5-flash")
         result = await client.acompletion("What is 3+3? Reply with just the number.")
         assert "6" in result
 
+    @pytest.mark.skipif(
+        not os.environ.get("GEMINI_API_KEY"),
+        reason="GEMINI_API_KEY not set",
+    )
+    def test_streaming_completion(self):
+        """Test streaming completion."""
+        client = GeminiClient(model_name="gemini-2.5-flash")
+        chunks = list(client.completion_stream("Count to 3. Reply with just numbers."))
+        result = "".join(chunks)
+        assert "1" in result
+        assert "2" in result
+        assert "3" in result
+
+        # Verify usage was tracked
+        usage = client.get_usage_summary()
+        assert "gemini-2.5-flash" in usage.model_usage_summaries
+        assert usage.model_usage_summaries["gemini-2.5-flash"].total_calls == 1
+
+    @pytest.mark.skipif(
+        not os.environ.get("GEMINI_API_KEY"),
+        reason="GEMINI_API_KEY not set",
+    )
+    @pytest.mark.asyncio
+    async def test_async_streaming_completion(self):
+        """Test async streaming completion."""
+        client = GeminiClient(model_name="gemini-2.5-flash")
+        chunks = []
+        async for chunk in client.acompletion_stream("Count from 4 to 6. Reply with just numbers."):
+            chunks.append(chunk)
+        result = "".join(chunks)
+        assert "4" in result
+        assert "5" in result
+        assert "6" in result
+
+        # Verify usage was tracked
+        usage = client.get_usage_summary()
+        assert "gemini-2.5-flash" in usage.model_usage_summaries
+        assert usage.model_usage_summaries["gemini-2.5-flash"].total_calls == 1
+
 
 if __name__ == "__main__":
     # Run integration tests directly
@@ -186,4 +350,10 @@ async def test_async_completion(self):
     test.test_simple_completion()
     print("Testing message list completion...")
     test.test_message_list_completion()
+    print("Testing async completion...")
+    asyncio.run(test.test_async_completion())
+    print("Testing streaming completion...")
+    test.test_streaming_completion()
+    print("Testing async streaming completion...")
+    asyncio.run(test.test_async_streaming_completion())
     print("All integration tests passed!")

From ff7cf57e8f3068c18c295f6e8546fb51a8f3f7ce Mon Sep 17 00:00:00 2001
From: Karl Weinmeister
Date: Mon, 5 Jan 2026 19:44:35 -0600
Subject: [PATCH 2/2] fix: remove duplicate test assertion

---
 tests/clients/test_gemini.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/clients/test_gemini.py b/tests/clients/test_gemini.py
index 91944084..9ed401c2 100644
--- a/tests/clients/test_gemini.py
+++ b/tests/clients/test_gemini.py
@@ -120,7 +120,7 @@ def test_completion_requires_model(self):
         with pytest.raises(ValueError, match="Model name is required"):
             client.completion("Hello")
 
-    def test_completion_mocked(self):
+    def test_completion_with_mocked_response(self):
         """Test completion with mocked API response."""
         mock_response = MagicMock()
         mock_response.text = "Hello from Gemini!"
@@ -138,7 +138,6 @@ def test_completion_with_mocked_response(self):
         assert result == "Hello from Gemini!"
         assert client.model_call_counts["gemini-2.5-flash"] == 1
         assert client.model_input_tokens["gemini-2.5-flash"] == 10
-        assert client.model_input_tokens["gemini-2.5-flash"] == 10
         assert client.model_output_tokens["gemini-2.5-flash"] == 5
 
     @pytest.mark.asyncio
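
A minimal usage sketch of the features added by this series (not itself part of the patches): it uses only the constructor arguments and methods visible in the diff above, imported from rlm/clients/gemini.py as the tests do. The project ID, prompts, and API key below are placeholders.

    from rlm.clients.gemini import GeminiClient

    # Gemini API with streaming; if api_key is omitted, GEMINI_API_KEY is read from the environment.
    client = GeminiClient(model_name="gemini-2.5-flash")
    for chunk in client.completion_stream("Count to 3. Reply with just numbers."):
        print(chunk, end="")

    # Vertex AI instead of the Gemini API; project/location may also come from
    # GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION.
    vertex_client = GeminiClient(vertexai=True, project="my-project", location="us-central1")

    # Optional safety settings are stored on the client and applied to the
    # GenerateContentConfig of every (streaming or non-streaming) call.
    safe_client = GeminiClient(
        api_key="YOUR_API_KEY",
        safety_settings=[
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_LOW_AND_ABOVE"}
        ],
    )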