diff --git a/tests/clients/test_anthropic.py b/tests/clients/test_anthropic.py new file mode 100644 index 00000000..fccbf38c --- /dev/null +++ b/tests/clients/test_anthropic.py @@ -0,0 +1,504 @@ +"""Tests for the Anthropic client.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +from dotenv import load_dotenv + +from rlm.clients.anthropic import AnthropicClient +from rlm.core.types import ModelUsageSummary, UsageSummary + +load_dotenv() + + +class TestAnthropicClientUnit: + """Unit tests that don't require API calls.""" + + def test_init_with_api_key(self): + """Test client initialization with explicit API key.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + assert client.model_name == "claude-sonnet-4-20250514" + assert client.max_tokens == 32768 + + def test_init_with_custom_max_tokens(self): + """Test client initialization with custom max_tokens.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient( + api_key="test-key", + model_name="claude-sonnet-4-20250514", + max_tokens=4096, + ) + assert client.max_tokens == 4096 + + def test_usage_tracking_initialization(self): + """Test that usage tracking is properly initialized.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + assert dict(client.model_call_counts) == {} + assert dict(client.model_input_tokens) == {} + assert dict(client.model_output_tokens) == {} + assert dict(client.model_total_tokens) == {} + + def test_get_usage_summary_empty(self): + """Test usage summary when no calls have been made.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + summary = client.get_usage_summary() + assert isinstance(summary, UsageSummary) + assert summary.model_usage_summaries == {} + + def test_get_last_usage(self): + """Test last usage returns correct format.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + client.last_prompt_tokens = 100 + client.last_completion_tokens = 50 + usage = client.get_last_usage() + assert isinstance(usage, ModelUsageSummary) + assert usage.total_calls == 1 + assert usage.total_input_tokens == 100 + assert usage.total_output_tokens == 50 + + def test_completion_requires_model(self): + """Test completion raises when no model specified.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key", model_name=None) + with pytest.raises(ValueError, match="Model name is required"): + client.completion("Hello") + + def test_completion_invalid_prompt_type(self): + """Test completion raises on invalid prompt type.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + with pytest.raises(ValueError, match="Invalid prompt type"): + client.completion(12345) + + +class TestAnthropicPrepareMessages: + """Tests 
for the _prepare_messages method.""" + + def test_prepare_messages_string_prompt(self): + """Test _prepare_messages converts string to user message.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + messages, system = client._prepare_messages("Hello world") + assert messages == [{"role": "user", "content": "Hello world"}] + assert system is None + + def test_prepare_messages_list_without_system(self): + """Test _prepare_messages with message list, no system message.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + input_messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + {"role": "user", "content": "How are you?"}, + ] + messages, system = client._prepare_messages(input_messages) + assert messages == input_messages + assert system is None + + def test_prepare_messages_extracts_system_message(self): + """Test _prepare_messages extracts system message from list.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + input_messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello"}, + ] + messages, system = client._prepare_messages(input_messages) + assert system == "You are a helpful assistant" + assert messages == [{"role": "user", "content": "Hello"}] + + def test_prepare_messages_preserves_non_system_messages(self): + """Test _prepare_messages preserves all non-system messages.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + input_messages = [ + {"role": "system", "content": "Be helpful"}, + {"role": "user", "content": "Question 1"}, + {"role": "assistant", "content": "Answer 1"}, + {"role": "user", "content": "Question 2"}, + ] + messages, system = client._prepare_messages(input_messages) + assert system == "Be helpful" + assert len(messages) == 3 + assert messages[0] == {"role": "user", "content": "Question 1"} + assert messages[1] == {"role": "assistant", "content": "Answer 1"} + assert messages[2] == {"role": "user", "content": "Question 2"} + + def test_prepare_messages_invalid_type_raises(self): + """Test _prepare_messages raises on invalid input type.""" + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key") + with pytest.raises(ValueError, match="Invalid prompt type"): + client._prepare_messages(12345) + + +class TestAnthropicCompletion: + """Tests for the completion method with mocked API.""" + + def test_completion_with_string_prompt(self): + """Test completion converts string prompt to messages.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Hello from Anthropic!" 
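+ # The mock mirrors the shape of Anthropic's Messages API response: a list of content blocks exposing .text, with token counts reported via usage.input_tokens and usage.output_tokens.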
+ mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + result = client.completion("Hello") + + assert result == "Hello from Anthropic!" + call_args = mock_client.messages.create.call_args + assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello"}] + assert call_args.kwargs["model"] == "claude-sonnet-4-20250514" + assert "system" not in call_args.kwargs + + def test_completion_with_message_list(self): + """Test completion with message list format.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 20 + mock_response.usage.output_tokens = 10 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "How are you?"}, + ] + result = client.completion(messages) + + assert result == "Response" + call_args = mock_client.messages.create.call_args + assert call_args.kwargs["messages"] == messages + + def test_completion_with_system_message(self): + """Test completion passes system message to API.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "I am helpful" + mock_response.usage.input_tokens = 15 + mock_response.usage.output_tokens = 8 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + result = client.completion(messages) + + assert result == "I am helpful" + call_args = mock_client.messages.create.call_args + assert call_args.kwargs["system"] == "You are helpful" + assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello"}] + + def test_completion_with_model_override(self): + """Test completion uses provided model over default.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + client.completion("Test", 
model="claude-3-haiku-20240307") + + call_args = mock_client.messages.create.call_args + assert call_args.kwargs["model"] == "claude-3-haiku-20240307" + + +class TestAnthropicUsageTracking: + """Tests for usage tracking functionality.""" + + def test_completion_tracks_usage(self): + """Test completion properly tracks usage statistics.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 100 + mock_response.usage.output_tokens = 50 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + client.completion("Test") + + assert client.model_call_counts["claude-sonnet-4-20250514"] == 1 + assert client.model_input_tokens["claude-sonnet-4-20250514"] == 100 + assert client.model_output_tokens["claude-sonnet-4-20250514"] == 50 + assert client.model_total_tokens["claude-sonnet-4-20250514"] == 150 + assert client.last_prompt_tokens == 100 + assert client.last_completion_tokens == 50 + + def test_completion_multiple_calls_accumulate_usage(self): + """Test that multiple completions accumulate usage correctly.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + client.completion("Test 1") + client.completion("Test 2") + client.completion("Test 3") + + assert client.model_call_counts["claude-sonnet-4-20250514"] == 3 + assert client.model_input_tokens["claude-sonnet-4-20250514"] == 30 + assert client.model_output_tokens["claude-sonnet-4-20250514"] == 15 + assert client.model_total_tokens["claude-sonnet-4-20250514"] == 45 + + def test_completion_with_different_models_tracks_separately(self): + """Test completion tracks usage separately for different models.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + client.completion("Test 1") + client.completion("Test 2", model="claude-3-haiku-20240307") + client.completion("Test 3", model="claude-sonnet-4-20250514") + + assert client.model_call_counts["claude-sonnet-4-20250514"] == 2 + assert client.model_call_counts["claude-3-haiku-20240307"] == 1 + + def test_get_usage_summary_after_calls(self): + """Test usage summary returns correct data after calls.""" + mock_response = MagicMock() + mock_response.content = [MagicMock()] 
+ mock_response.content[0].text = "Response" + mock_response.usage.input_tokens = 100 + mock_response.usage.output_tokens = 50 + + with patch("rlm.clients.anthropic.anthropic.Anthropic") as mock_anthropic_class: + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_response + mock_anthropic_class.return_value = mock_client + + client = AnthropicClient(api_key="test-key", model_name="claude-sonnet-4-20250514") + client.completion("Test") + + summary = client.get_usage_summary() + assert "claude-sonnet-4-20250514" in summary.model_usage_summaries + model_summary = summary.model_usage_summaries["claude-sonnet-4-20250514"] + assert model_summary.total_calls == 1 + assert model_summary.total_input_tokens == 100 + assert model_summary.total_output_tokens == 50 + + +class TestAnthropicClientAsync: + """Tests for async completion method.""" + + def test_acompletion_with_string_prompt(self): + """Test async completion with string prompt.""" + import asyncio + + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "Async response" + mock_response.usage.input_tokens = 10 + mock_response.usage.output_tokens = 5 + + async def run_test(): + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic") as mock_async_class: + mock_async_client = MagicMock() + + async def mock_create(**kwargs): + return mock_response + + mock_async_client.messages.create = mock_create + mock_async_class.return_value = mock_async_client + + client = AnthropicClient( + api_key="test-key", model_name="claude-sonnet-4-20250514" + ) + result = await client.acompletion("Hello") + + assert result == "Async response" + + asyncio.run(run_test()) + + def test_acompletion_requires_model(self): + """Test async completion raises when no model specified.""" + import asyncio + + async def run_test(): + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient(api_key="test-key", model_name=None) + with pytest.raises(ValueError, match="Model name is required"): + await client.acompletion("Hello") + + asyncio.run(run_test()) + + def test_acompletion_invalid_prompt_type(self): + """Test async completion raises on invalid prompt type.""" + import asyncio + + async def run_test(): + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic"): + client = AnthropicClient( + api_key="test-key", model_name="claude-sonnet-4-20250514" + ) + with pytest.raises(ValueError, match="Invalid prompt type"): + await client.acompletion(12345) + + asyncio.run(run_test()) + + def test_acompletion_with_system_message(self): + """Test async completion with system message.""" + import asyncio + + mock_response = MagicMock() + mock_response.content = [MagicMock()] + mock_response.content[0].text = "I am helpful" + mock_response.usage.input_tokens = 15 + mock_response.usage.output_tokens = 8 + + async def run_test(): + with patch("rlm.clients.anthropic.anthropic.Anthropic"): + with patch("rlm.clients.anthropic.anthropic.AsyncAnthropic") as mock_async_class: + mock_async_client = MagicMock() + captured_kwargs = {} + + async def mock_create(**kwargs): + captured_kwargs.update(kwargs) + return mock_response + + mock_async_client.messages.create = mock_create + mock_async_class.return_value = mock_async_client + + client = 
AnthropicClient( + api_key="test-key", model_name="claude-sonnet-4-20250514" + ) + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + result = await client.acompletion(messages) + + assert result == "I am helpful" + assert captured_kwargs["system"] == "You are helpful" + assert captured_kwargs["messages"] == [{"role": "user", "content": "Hello"}] + + asyncio.run(run_test()) + + +class TestAnthropicClientIntegration: + """Integration tests that require a real API key.""" + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY"), + reason="ANTHROPIC_API_KEY not set", + ) + def test_simple_completion(self): + """Test a simple completion with real API.""" + client = AnthropicClient(model_name="claude-3-haiku-20240307") + result = client.completion("What is 2+2? Reply with just the number.") + assert "4" in result + + # Verify usage was tracked + usage = client.get_usage_summary() + assert "claude-3-haiku-20240307" in usage.model_usage_summaries + assert usage.model_usage_summaries["claude-3-haiku-20240307"].total_calls == 1 + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY"), + reason="ANTHROPIC_API_KEY not set", + ) + def test_message_list_completion(self): + """Test completion with message list format.""" + client = AnthropicClient(model_name="claude-3-haiku-20240307") + messages = [ + {"role": "system", "content": "You are a helpful math tutor."}, + {"role": "user", "content": "What is 5 * 5? Reply with just the number."}, + ] + result = client.completion(messages) + assert "25" in result + + @pytest.mark.skipif( + not os.environ.get("ANTHROPIC_API_KEY"), + reason="ANTHROPIC_API_KEY not set", + ) + def test_async_completion(self): + """Test async completion.""" + import asyncio + + async def run_test(): + client = AnthropicClient(model_name="claude-3-haiku-20240307") + result = await client.acompletion("What is 3+3? 
Reply with just the number.") + assert "6" in result + + asyncio.run(run_test()) + + +if __name__ == "__main__": + # Run integration tests directly + test = TestAnthropicClientIntegration() + print("Testing simple completion...") + test.test_simple_completion() + print("Testing message list completion...") + test.test_message_list_completion() + print("All integration tests passed!") diff --git a/tests/clients/test_openai.py b/tests/clients/test_openai.py new file mode 100644 index 00000000..9bf79140 --- /dev/null +++ b/tests/clients/test_openai.py @@ -0,0 +1,430 @@ +"""Tests for the OpenAI client.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest +from dotenv import load_dotenv + +from rlm.clients.openai import ( + DEFAULT_PRIME_INTELLECT_BASE_URL, + OpenAIClient, +) +from rlm.core.types import ModelUsageSummary, UsageSummary + +load_dotenv() + + +class TestOpenAIClientUnit: + """Unit tests that don't require API calls.""" + + def test_init_with_api_key(self): + """Test client initialization with explicit API key.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + assert client.model_name == "gpt-4o" + + def test_init_with_base_url(self): + """Test client initialization with custom base URL.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai: + with patch("rlm.clients.openai.openai.AsyncOpenAI") as mock_async: + client = OpenAIClient( + api_key="test-key", + model_name="gpt-4o", + base_url="https://custom.api.com/v1", + ) + mock_openai.assert_called_once_with( + api_key="test-key", base_url="https://custom.api.com/v1" + ) + mock_async.assert_called_once_with( + api_key="test-key", base_url="https://custom.api.com/v1" + ) + assert client.model_name == "gpt-4o" + + def test_init_auto_selects_openrouter_key(self): + """Test client auto-selects OpenRouter API key for OpenRouter base URL.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + with patch("rlm.clients.openai.DEFAULT_OPENROUTER_API_KEY", "openrouter-key"): + OpenAIClient( + api_key=None, + model_name="gpt-4o", + base_url="https://openrouter.ai/api/v1", + ) + mock_openai.assert_called_once_with( + api_key="openrouter-key", base_url="https://openrouter.ai/api/v1" + ) + + def test_init_auto_selects_vercel_key(self): + """Test client auto-selects Vercel API key for AI Gateway base URL.""" + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + with patch("rlm.clients.openai.DEFAULT_VERCEL_API_KEY", "vercel-key"): + OpenAIClient( + api_key=None, + model_name="gpt-4o", + base_url="https://ai-gateway.vercel.sh/v1", + ) + mock_openai.assert_called_once_with( + api_key="vercel-key", base_url="https://ai-gateway.vercel.sh/v1" + ) + + def test_usage_tracking_initialization(self): + """Test that usage tracking is properly initialized.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key") + assert dict(client.model_call_counts) == {} + assert dict(client.model_input_tokens) == {} + assert dict(client.model_output_tokens) == {} + assert dict(client.model_total_tokens) == {} + + def test_get_usage_summary_empty(self): + """Test usage summary when no calls have been made.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with 
patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key") + summary = client.get_usage_summary() + assert isinstance(summary, UsageSummary) + assert summary.model_usage_summaries == {} + + def test_get_last_usage(self): + """Test last usage returns correct format.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key") + client.last_prompt_tokens = 100 + client.last_completion_tokens = 50 + usage = client.get_last_usage() + assert isinstance(usage, ModelUsageSummary) + assert usage.total_calls == 1 + assert usage.total_input_tokens == 100 + assert usage.total_output_tokens == 50 + + def test_completion_requires_model(self): + """Test completion raises when no model specified.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name=None) + with pytest.raises(ValueError, match="Model name is required"): + client.completion("Hello") + + def test_completion_invalid_prompt_type(self): + """Test completion raises on invalid prompt type.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + with pytest.raises(ValueError, match="Invalid prompt type"): + client.completion(12345) + + def test_completion_with_string_prompt(self): + """Test completion converts string prompt to messages.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Hello from OpenAI!" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + result = client.completion("Hello") + + assert result == "Hello from OpenAI!" 
+ # Verify the prompt was converted to messages format + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["messages"] == [{"role": "user", "content": "Hello"}] + + def test_completion_with_message_list(self): + """Test completion with message list format.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 20 + mock_response.usage.completion_tokens = 10 + mock_response.usage.total_tokens = 30 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + result = client.completion(messages) + + assert result == "Response" + # Verify messages were passed directly + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["messages"] == messages + + def test_completion_tracks_usage(self): + """Test completion properly tracks usage statistics.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 100 + mock_response.usage.completion_tokens = 50 + mock_response.usage.total_tokens = 150 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + client.completion("Test") + + assert client.model_call_counts["gpt-4o"] == 1 + assert client.model_input_tokens["gpt-4o"] == 100 + assert client.model_output_tokens["gpt-4o"] == 50 + assert client.model_total_tokens["gpt-4o"] == 150 + assert client.last_prompt_tokens == 100 + assert client.last_completion_tokens == 50 + + def test_completion_multiple_calls_accumulate_usage(self): + """Test that multiple completions accumulate usage correctly.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + client.completion("Test 1") + client.completion("Test 2") + client.completion("Test 3") + + assert client.model_call_counts["gpt-4o"] == 3 + assert client.model_input_tokens["gpt-4o"] == 30 + assert client.model_output_tokens["gpt-4o"] == 15 + assert client.model_total_tokens["gpt-4o"] == 45 + + def test_get_usage_summary_after_calls(self): + """Test usage summary returns correct data after calls.""" + mock_response = MagicMock() + 
mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 100 + mock_response.usage.completion_tokens = 50 + mock_response.usage.total_tokens = 150 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + client.completion("Test") + + summary = client.get_usage_summary() + assert "gpt-4o" in summary.model_usage_summaries + model_summary = summary.model_usage_summaries["gpt-4o"] + assert model_summary.total_calls == 1 + assert model_summary.total_input_tokens == 100 + assert model_summary.total_output_tokens == 50 + + def test_track_cost_raises_on_missing_usage(self): + """Test _track_cost raises when response has no usage data.""" + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + mock_response = MagicMock() + mock_response.usage = None + + with pytest.raises(ValueError, match="No usage data received"): + client._track_cost(mock_response, "gpt-4o") + + def test_completion_with_prime_intellect_base_url(self): + """Test completion adds extra_body for Prime Intellect API.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = DEFAULT_PRIME_INTELLECT_BASE_URL + mock_openai_class.return_value = mock_client + + client = OpenAIClient( + api_key="test-key", + model_name="test-model", + base_url=DEFAULT_PRIME_INTELLECT_BASE_URL, + ) + client.completion("Test") + + # Verify extra_body was passed + call_args = mock_client.chat.completions.create.call_args + assert call_args.kwargs["extra_body"] == {"usage": {"include": True}} + + def test_completion_with_different_models(self): + """Test completion tracks usage separately for different models.""" + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Response" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + with patch("rlm.clients.openai.openai.OpenAI") as mock_openai_class: + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + mock_client.base_url = "https://api.openai.com/v1" + mock_openai_class.return_value = mock_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + client.completion("Test 1") + client.completion("Test 2", model="gpt-4o-mini") + client.completion("Test 3", model="gpt-4o") + + assert client.model_call_counts["gpt-4o"] == 2 + assert client.model_call_counts["gpt-4o-mini"] == 1 + + +class TestOpenAIClientAsync: + """Tests for async completion method.""" + + def test_acompletion_with_string_prompt(self): + """Test 
async completion with string prompt.""" + import asyncio + + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Async response" + mock_response.usage.prompt_tokens = 10 + mock_response.usage.completion_tokens = 5 + mock_response.usage.total_tokens = 15 + + async def run_test(): + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI") as mock_async_class: + mock_async_client = MagicMock() + + # Make the async create method return a coroutine + async def mock_create(**kwargs): + return mock_response + + mock_async_client.chat.completions.create = mock_create + mock_async_class.return_value = mock_async_client + + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + client.client.base_url = "https://api.openai.com/v1" + result = await client.acompletion("Hello") + + assert result == "Async response" + + asyncio.run(run_test()) + + def test_acompletion_requires_model(self): + """Test async completion raises when no model specified.""" + import asyncio + + async def run_test(): + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name=None) + with pytest.raises(ValueError, match="Model name is required"): + await client.acompletion("Hello") + + asyncio.run(run_test()) + + def test_acompletion_invalid_prompt_type(self): + """Test async completion raises on invalid prompt type.""" + import asyncio + + async def run_test(): + with patch("rlm.clients.openai.openai.OpenAI"): + with patch("rlm.clients.openai.openai.AsyncOpenAI"): + client = OpenAIClient(api_key="test-key", model_name="gpt-4o") + with pytest.raises(ValueError, match="Invalid prompt type"): + await client.acompletion(12345) + + asyncio.run(run_test()) + + +class TestOpenAIClientIntegration: + """Integration tests that require a real API key.""" + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY not set", + ) + def test_simple_completion(self): + """Test a simple completion with real API.""" + client = OpenAIClient(model_name="gpt-4o-mini") + result = client.completion("What is 2+2? Reply with just the number.") + assert "4" in result + + # Verify usage was tracked + usage = client.get_usage_summary() + assert "gpt-4o-mini" in usage.model_usage_summaries + assert usage.model_usage_summaries["gpt-4o-mini"].total_calls == 1 + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY not set", + ) + def test_message_list_completion(self): + """Test completion with message list format.""" + client = OpenAIClient(model_name="gpt-4o-mini") + messages = [ + {"role": "system", "content": "You are a helpful math tutor."}, + {"role": "user", "content": "What is 5 * 5? Reply with just the number."}, + ] + result = client.completion(messages) + assert "25" in result + + @pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY"), + reason="OPENAI_API_KEY not set", + ) + def test_async_completion(self): + """Test async completion.""" + import asyncio + + async def run_test(): + client = OpenAIClient(model_name="gpt-4o-mini") + result = await client.acompletion("What is 3+3? 
Reply with just the number.") + assert "6" in result + + asyncio.run(run_test()) + + +if __name__ == "__main__": + # Run integration tests directly + test = TestOpenAIClientIntegration() + print("Testing simple completion...") + test.test_simple_completion() + print("Testing message list completion...") + test.test_message_list_completion() + print("All integration tests passed!") diff --git a/tests/test_comms_utils.py b/tests/test_comms_utils.py new file mode 100644 index 00000000..cc4c5cf0 --- /dev/null +++ b/tests/test_comms_utils.py @@ -0,0 +1,460 @@ +"""Tests for communication utilities. + +Tests the socket protocol and message dataclasses used for +LM Handler <-> Environment subprocess communication. +""" + +import json +import struct +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from rlm.core.comms_utils import ( + LMRequest, + LMResponse, + send_lm_request, + send_lm_request_batched, + socket_recv, + socket_request, + socket_send, +) +from rlm.core.types import ModelUsageSummary, RLMChatCompletion, UsageSummary + + +def make_chat_completion(response: str) -> RLMChatCompletion: + """Helper to create a RLMChatCompletion for tests.""" + return RLMChatCompletion( + root_model="test-model", + prompt="test prompt", + response=response, + usage_summary=UsageSummary( + model_usage_summaries={ + "test-model": ModelUsageSummary( + total_calls=1, + total_input_tokens=10, + total_output_tokens=10, + ) + } + ), + execution_time=0.1, + ) + + +# LMRequest Tests + + +class TestLMRequest: + """Tests for LMRequest dataclass.""" + + def test_single_prompt_creation(self): + """Test creating a request with a single prompt.""" + request = LMRequest(prompt="Hello, world!") + assert request.prompt == "Hello, world!" + assert request.prompts is None + assert request.is_batched is False + + def test_batched_prompts_creation(self): + """Test creating a request with multiple prompts.""" + prompts = ["Prompt 1", "Prompt 2", "Prompt 3"] + request = LMRequest(prompts=prompts) + assert request.prompt is None + assert request.prompts == prompts + assert request.is_batched is True + + def test_empty_prompts_not_batched(self): + """Test that empty prompts list is not considered batched.""" + request = LMRequest(prompts=[]) + assert request.is_batched is False + + def test_model_and_depth(self): + """Test model and depth fields.""" + request = LMRequest(prompt="Test", model="gpt-4", depth=2) + assert request.model == "gpt-4" + assert request.depth == 2 + + def test_default_depth(self): + """Test default depth is 0.""" + request = LMRequest(prompt="Test") + assert request.depth == 0 + + def test_to_dict_single_prompt(self): + """Test converting single prompt request to dict.""" + request = LMRequest(prompt="Hello", model="gpt-4", depth=1) + d = request.to_dict() + assert d["prompt"] == "Hello" + assert d["model"] == "gpt-4" + assert d["depth"] == 1 + assert "prompts" not in d + + def test_to_dict_batched(self): + """Test converting batched request to dict.""" + request = LMRequest(prompts=["A", "B"], depth=0) + d = request.to_dict() + assert d["prompts"] == ["A", "B"] + assert d["depth"] == 0 + assert "prompt" not in d + assert "model" not in d + + def test_from_dict_single_prompt(self): + """Test creating from dict with single prompt.""" + data = {"prompt": "Hello", "model": "gpt-4", "depth": 1} + request = LMRequest.from_dict(data) + assert request.prompt == "Hello" + assert request.model == "gpt-4" + assert request.depth == 1 + + def test_from_dict_batched(self): + """Test creating from dict 
with batched prompts.""" + data = {"prompts": ["A", "B", "C"], "depth": 2} + request = LMRequest.from_dict(data) + assert request.prompts == ["A", "B", "C"] + assert request.depth == 2 + + def test_roundtrip(self): + """Test dict conversion roundtrip.""" + original = LMRequest(prompt="Test", model="gpt-4", depth=3) + restored = LMRequest.from_dict(original.to_dict()) + assert restored.prompt == original.prompt + assert restored.model == original.model + assert restored.depth == original.depth + + +# LMResponse Tests + + +class TestLMResponse: + """Tests for LMResponse dataclass.""" + + def test_success_response(self): + """Test creating a successful response.""" + completion = make_chat_completion("Hello!") + response = LMResponse.success_response(completion) + assert response.success is True + assert response.error is None + assert response.chat_completion.response == "Hello!" + + def test_error_response(self): + """Test creating an error response.""" + response = LMResponse.error_response("Something went wrong") + assert response.success is False + assert response.error == "Something went wrong" + assert response.chat_completion is None + + def test_batched_success_response(self): + """Test creating a batched successful response.""" + completions = [ + make_chat_completion("Response 1"), + make_chat_completion("Response 2"), + ] + response = LMResponse.batched_success_response(completions) + assert response.success is True + assert response.is_batched is True + assert len(response.chat_completions) == 2 + + def test_is_batched_property(self): + """Test is_batched property.""" + single = LMResponse.success_response(make_chat_completion("Test")) + assert single.is_batched is False + + batched = LMResponse.batched_success_response([make_chat_completion("Test")]) + assert batched.is_batched is True + + def test_to_dict_success(self): + """Test converting successful response to dict.""" + completion = make_chat_completion("Hello!") + response = LMResponse.success_response(completion) + d = response.to_dict() + assert d["error"] is None + assert d["chat_completion"]["response"] == "Hello!" + assert d["chat_completions"] is None + + def test_to_dict_error(self): + """Test converting error response to dict.""" + response = LMResponse.error_response("Failed") + d = response.to_dict() + assert d["error"] == "Failed" + assert d["chat_completion"] is None + + def test_from_dict_success(self): + """Test creating from dict with successful response.""" + data = { + "chat_completion": { + "root_model": "test-model", + "prompt": "test", + "response": "Hello!", + "usage_summary": {"model_usage_summaries": {}}, + "execution_time": 0.1, + }, + "error": None, + "chat_completions": None, + } + response = LMResponse.from_dict(data) + assert response.success is True + assert response.chat_completion.response == "Hello!" 
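+ # Note: "success" is never serialized; as the dicts in this test and the next show, from_dict recomputes it from the error / chat_completion fields.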
+ + def test_from_dict_error(self): + """Test creating from dict with error.""" + data = {"error": "Something failed", "chat_completion": None} + response = LMResponse.from_dict(data) + assert response.success is False + assert response.error == "Something failed" + + def test_empty_response_is_error(self): + """Test that response with no completion or error produces error dict.""" + response = LMResponse() + d = response.to_dict() + assert "No chat completion or error provided" in d["error"] + + +# Socket Protocol Tests + + +class TestSocketProtocol: + """Tests for socket send/recv protocol.""" + + def test_socket_send_format(self): + """Test that socket_send uses correct protocol format.""" + mock_sock = Mock() + data = {"message": "Hello"} + + socket_send(mock_sock, data) + + # Verify sendall was called once + mock_sock.sendall.assert_called_once() + + # Get the sent bytes + sent_bytes = mock_sock.sendall.call_args[0][0] + + # First 4 bytes should be big-endian length + length = struct.unpack(">I", sent_bytes[:4])[0] + payload = sent_bytes[4:] + + assert length == len(payload) + assert json.loads(payload.decode("utf-8")) == data + + def test_socket_recv_success(self): + """Test receiving a valid message.""" + data = {"response": "World"} + payload = json.dumps(data).encode("utf-8") + length_prefix = struct.pack(">I", len(payload)) + + mock_sock = Mock() + # First call returns length, second returns payload + mock_sock.recv.side_effect = [length_prefix, payload] + + result = socket_recv(mock_sock) + assert result == data + + def test_socket_recv_empty_returns_empty_dict(self): + """Test that empty recv (connection closed) returns empty dict.""" + mock_sock = Mock() + mock_sock.recv.return_value = b"" + + result = socket_recv(mock_sock) + assert result == {} + + def test_socket_recv_partial_payload(self): + """Test receiving payload in multiple chunks.""" + data = {"large": "x" * 1000} + payload = json.dumps(data).encode("utf-8") + length_prefix = struct.pack(">I", len(payload)) + + mock_sock = Mock() + # Simulate receiving in chunks + chunk1 = payload[:500] + chunk2 = payload[500:] + mock_sock.recv.side_effect = [length_prefix, chunk1, chunk2] + + result = socket_recv(mock_sock) + assert result == data + + def test_socket_recv_connection_closed_mid_message(self): + """Test that connection closing mid-message raises error.""" + mock_sock = Mock() + length_prefix = struct.pack(">I", 1000) # Expect 1000 bytes + mock_sock.recv.side_effect = [length_prefix, b"short", b""] # Connection closes + + with pytest.raises(ConnectionError, match="Connection closed before message complete"): + socket_recv(mock_sock) + + +# socket_request Tests + + +class TestSocketRequest: + """Tests for socket_request function.""" + + def test_socket_request_success(self): + """Test successful request/response cycle.""" + request_data = {"prompt": "Hello"} + response_data = {"content": "World"} + + with patch("rlm.core.comms_utils.socket.socket") as mock_socket_class: + mock_sock = MagicMock() + mock_socket_class.return_value.__enter__.return_value = mock_sock + + # Set up recv to return response + payload = json.dumps(response_data).encode("utf-8") + length_prefix = struct.pack(">I", len(payload)) + mock_sock.recv.side_effect = [length_prefix, payload] + + result = socket_request(("localhost", 5000), request_data) + + # Verify connection was made + mock_sock.connect.assert_called_once_with(("localhost", 5000)) + # Verify timeout was set + mock_sock.settimeout.assert_called_once_with(300) + # Verify response + assert 
result == response_data + + +# send_lm_request Tests + + +class TestSendLMRequest: + """Tests for send_lm_request function.""" + + def test_send_lm_request_success(self): + """Test successful LM request.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = { + "chat_completion": { + "root_model": "test", + "prompt": "test", + "response": "Hello!", + "usage_summary": {"model_usage_summaries": {}}, + "execution_time": 0.1, + }, + "error": None, + } + + request = LMRequest(prompt="Test") + response = send_lm_request(("localhost", 5000), request) + + assert response.success is True + assert response.chat_completion.response == "Hello!" + + def test_send_lm_request_error(self): + """Test LM request that returns error.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = {"error": "Model overloaded"} + + request = LMRequest(prompt="Test") + response = send_lm_request(("localhost", 5000), request) + + assert response.success is False + assert response.error == "Model overloaded" + + def test_send_lm_request_exception(self): + """Test LM request that raises exception.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.side_effect = ConnectionError("Connection refused") + + request = LMRequest(prompt="Test") + response = send_lm_request(("localhost", 5000), request) + + assert response.success is False + assert "Connection refused" in response.error + + def test_send_lm_request_depth_override(self): + """Test that depth parameter overrides request depth.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = { + "chat_completion": { + "root_model": "test", + "prompt": "test", + "response": "Test", + "usage_summary": {"model_usage_summaries": {}}, + "execution_time": 0.1, + }, + "error": None, + } + + request = LMRequest(prompt="Test", depth=0) + send_lm_request(("localhost", 5000), request, depth=5) + + # Verify the request was modified + assert request.depth == 5 + + +# send_lm_request_batched Tests + + +class TestSendLMRequestBatched: + """Tests for send_lm_request_batched function.""" + + def test_batched_success(self): + """Test successful batched request.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = { + "chat_completions": [ + { + "root_model": "test", + "prompt": "test", + "response": "Response 1", + "usage_summary": {"model_usage_summaries": {}}, + "execution_time": 0.1, + }, + { + "root_model": "test", + "prompt": "test", + "response": "Response 2", + "usage_summary": {"model_usage_summaries": {}}, + "execution_time": 0.1, + }, + ], + "error": None, + } + + responses = send_lm_request_batched( + ("localhost", 5000), + prompts=["Prompt 1", "Prompt 2"], + ) + + assert len(responses) == 2 + assert all(r.success for r in responses) + assert responses[0].chat_completion.response == "Response 1" + assert responses[1].chat_completion.response == "Response 2" + + def test_batched_error(self): + """Test batched request that returns error.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = {"error": "Rate limited"} + + responses = send_lm_request_batched( + ("localhost", 5000), + prompts=["A", "B", "C"], + ) + + assert len(responses) == 3 + assert all(not r.success for r in responses) + assert all("Rate limited" in r.error for r in responses) + + def test_batched_exception(self): + """Test batched request that raises 
exception.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.side_effect = TimeoutError("Connection timed out") + + responses = send_lm_request_batched( + ("localhost", 5000), + prompts=["A", "B"], + ) + + assert len(responses) == 2 + assert all(not r.success for r in responses) + assert all("Connection timed out" in r.error for r in responses) + + def test_batched_no_completions_returned(self): + """Test when server returns success but no completions.""" + with patch("rlm.core.comms_utils.socket_request") as mock_request: + mock_request.return_value = {"error": None, "chat_completions": None} + + responses = send_lm_request_batched( + ("localhost", 5000), + prompts=["A", "B"], + ) + + assert len(responses) == 2 + assert all(not r.success for r in responses) + assert all("No completions returned" in r.error for r in responses) diff --git a/tests/test_lm_handler.py b/tests/test_lm_handler.py new file mode 100644 index 00000000..3ce6dd1e --- /dev/null +++ b/tests/test_lm_handler.py @@ -0,0 +1,410 @@ +"""Tests for LMHandler. + +Tests the LMHandler class which routes LLM requests from +the RLM process and environment subprocesses. +""" + +from rlm.clients.base_lm import BaseLM +from rlm.core.lm_handler import LMHandler +from rlm.core.types import ModelUsageSummary, UsageSummary +from tests.mock_lm import MockLM + +# Test Fixtures + + +class AnotherMockLM(BaseLM): + """Another mock LM for testing multiple clients.""" + + def __init__(self, model_name: str = "another-mock-model"): + super().__init__(model_name=model_name) + + def completion(self, prompt): + return f"Another mock response to: {prompt[:50]}" + + async def acompletion(self, prompt): + return self.completion(prompt) + + def get_usage_summary(self): + return UsageSummary( + model_usage_summaries={ + self.model_name: ModelUsageSummary( + total_calls=2, total_input_tokens=20, total_output_tokens=20 + ) + } + ) + + def get_last_usage(self): + return ModelUsageSummary(total_calls=1, total_input_tokens=10, total_output_tokens=10) + + +# LMHandler Initialization Tests + + +class TestLMHandlerInit: + """Tests for LMHandler initialization.""" + + def test_basic_init(self): + """Test basic initialization with default values.""" + client = MockLM() + handler = LMHandler(client) + + assert handler.default_client is client + assert handler.other_backend_client is None + assert handler.host == "127.0.0.1" + assert handler._port == 0 + + def test_init_with_custom_host_port(self): + """Test initialization with custom host and port.""" + client = MockLM() + handler = LMHandler(client, host="0.0.0.0", port=8080) + + assert handler.host == "0.0.0.0" + assert handler._port == 8080 + + def test_init_with_other_backend(self): + """Test initialization with other_backend_client.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + assert handler.default_client is default_client + assert handler.other_backend_client is other_client + + def test_default_client_auto_registered(self): + """Test that default client is auto-registered by model name.""" + client = MockLM() + handler = LMHandler(client) + + assert "mock-model" in handler.clients + assert handler.clients["mock-model"] is client + + +# get_client Routing Tests + + +class TestGetClientRouting: + """Tests for get_client routing logic.""" + + def test_get_client_default_no_args(self): + """Test get_client returns default_client when no args provided.""" + client = MockLM() + 
handler = LMHandler(client) + + result = handler.get_client() + assert result is client + + def test_get_client_depth_0_returns_default(self): + """Test depth=0 returns default_client.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + result = handler.get_client(depth=0) + assert result is default_client + + def test_get_client_depth_1_with_other_backend(self): + """Test depth=1 returns other_backend_client when available.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + result = handler.get_client(depth=1) + assert result is other_client + + def test_get_client_depth_1_without_other_backend(self): + """Test depth=1 returns default_client when other_backend not set.""" + default_client = MockLM() + handler = LMHandler(default_client) + + result = handler.get_client(depth=1) + assert result is default_client + + def test_get_client_by_model_name(self): + """Test get_client returns registered client by model name.""" + default_client = MockLM() + handler = LMHandler(default_client) + + result = handler.get_client(model="mock-model") + assert result is default_client + + def test_get_client_model_overrides_depth(self): + """Test that model name lookup overrides depth routing.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + # Register another client with a specific name + special_client = AnotherMockLM(model_name="special-model") + handler.register_client("special-model", special_client) + + # Even with depth=1, should return special_client when model matches + result = handler.get_client(model="special-model", depth=1) + assert result is special_client + + def test_get_client_unknown_model_falls_back(self): + """Test unknown model name falls back to depth routing.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + # Unknown model with depth=1 should use other_backend + result = handler.get_client(model="unknown-model", depth=1) + assert result is other_client + + # Unknown model with depth=0 should use default + result = handler.get_client(model="unknown-model", depth=0) + assert result is default_client + + +# register_client Tests + + +class TestRegisterClient: + """Tests for register_client method.""" + + def test_register_new_client(self): + """Test registering a new client.""" + default_client = MockLM() + handler = LMHandler(default_client) + + new_client = AnotherMockLM(model_name="new-model") + handler.register_client("new-model", new_client) + + assert "new-model" in handler.clients + assert handler.clients["new-model"] is new_client + + def test_register_overwrites_existing(self): + """Test registering a client with existing name overwrites it.""" + default_client = MockLM() + handler = LMHandler(default_client) + + new_client = AnotherMockLM(model_name="mock-model") + handler.register_client("mock-model", new_client) + + assert handler.clients["mock-model"] is new_client + + +# Completion Tests + + +class TestCompletion: + """Tests for completion method using MockLM.""" + + def test_completion_default_client(self): + """Test completion uses default client.""" + client = MockLM() + handler = LMHandler(client) + + result = handler.completion("Hello, world!") + assert result == "Mock response to: Hello, world!" 
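+ # MockLM echoes the prompt back, so the response text doubles as a check of which client actually handled the call (the routing tests below rely on this).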
+ + def test_completion_with_model(self): + """Test completion with specific model.""" + default_client = MockLM() + other_client = AnotherMockLM(model_name="other-model") + handler = LMHandler(default_client) + handler.register_client("other-model", other_client) + + result = handler.completion("Test prompt", model="other-model") + assert result == "Another mock response to: Test prompt" + + +# Server Start/Stop Tests + + +class TestServerLifecycle: + """Tests for server start/stop methods.""" + + def test_start_creates_server(self): + """Test start creates server and returns address.""" + client = MockLM() + handler = LMHandler(client) + + try: + address = handler.start() + assert handler._server is not None + assert handler._thread is not None + assert address[0] == "127.0.0.1" + assert address[1] > 0 # Auto-assigned port + finally: + handler.stop() + + def test_start_idempotent(self): + """Test calling start multiple times returns same address.""" + client = MockLM() + handler = LMHandler(client) + + try: + address1 = handler.start() + address2 = handler.start() + assert address1 == address2 + finally: + handler.stop() + + def test_stop_clears_server(self): + """Test stop clears server and thread.""" + client = MockLM() + handler = LMHandler(client) + + handler.start() + handler.stop() + + assert handler._server is None + assert handler._thread is None + + def test_port_property_returns_actual_port(self): + """Test port property returns actual assigned port.""" + client = MockLM() + handler = LMHandler(client, port=0) + + try: + handler.start() + # Port should be > 0 after auto-assignment + assert handler.port > 0 + finally: + handler.stop() + + def test_address_property(self): + """Test address property returns (host, port) tuple.""" + client = MockLM() + handler = LMHandler(client) + + try: + handler.start() + address = handler.address + assert address == (handler.host, handler.port) + finally: + handler.stop() + + +# Context Manager Tests + + +class TestContextManager: + """Tests for context manager protocol.""" + + def test_enter_starts_server(self): + """Test __enter__ starts the server.""" + client = MockLM() + handler = LMHandler(client) + + with handler as h: + assert h is handler + assert handler._server is not None + + def test_exit_stops_server(self): + """Test __exit__ stops the server.""" + client = MockLM() + handler = LMHandler(client) + + with handler: + pass + + assert handler._server is None + + def test_exit_returns_false(self): + """Test __exit__ returns False (doesn't suppress exceptions).""" + client = MockLM() + handler = LMHandler(client) + + result = handler.__exit__(None, None, None) + assert result is False + + +# Usage Summary Tests + + +class TestUsageSummary: + """Tests for get_usage_summary method.""" + + def test_usage_summary_single_client(self): + """Test usage summary with single client.""" + client = MockLM() + handler = LMHandler(client) + + summary = handler.get_usage_summary() + assert "mock-model" in summary.model_usage_summaries + + def test_usage_summary_with_other_backend(self): + """Test usage summary includes other_backend_client.""" + default_client = MockLM() + other_client = AnotherMockLM() + handler = LMHandler(default_client, other_backend_client=other_client) + + summary = handler.get_usage_summary() + assert "mock-model" in summary.model_usage_summaries + assert "another-mock-model" in summary.model_usage_summaries + + def test_usage_summary_merges_all_clients(self): + """Test usage summary merges all registered clients.""" + 
default_client = MockLM() + handler = LMHandler(default_client) + + extra_client = AnotherMockLM(model_name="extra-model") + handler.register_client("extra-model", extra_client) + + summary = handler.get_usage_summary() + assert "mock-model" in summary.model_usage_summaries + assert "extra-model" in summary.model_usage_summaries + + +# Integration Tests (with socket communication) + + +class TestSocketIntegration: + """Integration tests for socket-based communication.""" + + def test_handler_accepts_connection(self): + """Test that handler accepts socket connections.""" + client = MockLM() + + with LMHandler(client) as handler: + # Server should be running and accepting connections + assert handler._server is not None + assert handler.port > 0 + + def test_full_request_response_cycle(self): + """Test full request/response cycle through socket.""" + from rlm.core.comms_utils import LMRequest, send_lm_request + + client = MockLM() + + with LMHandler(client) as handler: + request = LMRequest(prompt="Test prompt") + response = send_lm_request(handler.address, request) + + assert response.success is True + assert "Mock response to: Test prompt" in response.chat_completion.response + + def test_request_with_depth_routing(self): + """Test request with depth parameter routes correctly.""" + from rlm.core.comms_utils import LMRequest, send_lm_request + + default_client = MockLM() + other_client = AnotherMockLM() + + with LMHandler(default_client, other_backend_client=other_client) as handler: + # depth=0 should use default + request0 = LMRequest(prompt="Test", depth=0) + response0 = send_lm_request(handler.address, request0) + assert "Mock response" in response0.chat_completion.response + + # depth=1 should use other backend + request1 = LMRequest(prompt="Test", depth=1) + response1 = send_lm_request(handler.address, request1) + assert "Another mock response" in response1.chat_completion.response + + def test_batched_request(self): + """Test batched request handling.""" + from rlm.core.comms_utils import send_lm_request_batched + + client = MockLM() + + with LMHandler(client) as handler: + prompts = ["Prompt 1", "Prompt 2", "Prompt 3"] + responses = send_lm_request_batched(handler.address, prompts=prompts) + + assert len(responses) == 3 + assert all(r.success for r in responses) + for i, response in enumerate(responses): + assert f"Prompt {i + 1}" in response.chat_completion.response diff --git a/tests/test_rlm_logger.py b/tests/test_rlm_logger.py new file mode 100644 index 00000000..7777e692 --- /dev/null +++ b/tests/test_rlm_logger.py @@ -0,0 +1,426 @@ +"""Tests for RLMLogger.""" + +import json +import os +import tempfile +from pathlib import Path + +from rlm.core.types import CodeBlock, REPLResult, RLMIteration, RLMMetadata +from rlm.logger.rlm_logger import RLMLogger + + +class TestRLMLoggerInitialization: + """Tests for RLMLogger initialization and file creation.""" + + def test_creates_log_directory(self): + """Logger should create the log directory if it doesn't exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + log_dir = os.path.join(temp_dir, "nested", "logs") + logger = RLMLogger(log_dir) + + assert os.path.isdir(log_dir) + assert logger.log_dir == log_dir + + def test_log_file_path_format(self): + """Log file path should contain filename, timestamp, and run_id.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir, file_name="test_rlm") + log_path = Path(logger.log_file_path) + + assert log_path.suffix == ".jsonl" + assert 
log_path.name.startswith("test_rlm_") + # Format: test_rlm_YYYY-MM-DD_HH-MM-SS_xxxxxxxx.jsonl + parts = log_path.stem.split("_") + assert len(parts) >= 4 # file_name + date + time + run_id + + def test_default_file_name(self): + """Default file name should be 'rlm'.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + log_path = Path(logger.log_file_path) + + assert log_path.name.startswith("rlm_") + + def test_initial_iteration_count(self): + """Initial iteration count should be zero.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + + assert logger.iteration_count == 0 + + +class TestRLMLoggerMetadata: + """Tests for logging metadata.""" + + def test_log_metadata_creates_entry(self): + """log_metadata should write metadata as first entry.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="gpt-4", + max_depth=3, + max_iterations=10, + backend="openai", + backend_kwargs={"api_key": "test"}, + environment_type="local", + environment_kwargs={}, + ) + + logger.log_metadata(metadata) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["type"] == "metadata" + assert entry["root_model"] == "gpt-4" + assert entry["max_depth"] == 3 + assert entry["max_iterations"] == 10 + assert entry["backend"] == "openai" + assert entry["environment_type"] == "local" + assert "timestamp" in entry + + def test_log_metadata_only_once(self): + """log_metadata should only write once even if called multiple times.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="gpt-4", + max_depth=3, + max_iterations=10, + backend="openai", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + + logger.log_metadata(metadata) + logger.log_metadata(metadata) + logger.log_metadata(metadata) + + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 1 + + def test_metadata_with_other_backends(self): + """Metadata should include other_backends if provided.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="claude-3", + max_depth=2, + max_iterations=5, + backend="anthropic", + backend_kwargs={}, + environment_type="docker", + environment_kwargs={"image": "python:3.11"}, + other_backends=["openai", "gemini"], + ) + + logger.log_metadata(metadata) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["other_backends"] == ["openai", "gemini"] + + +class TestRLMLoggerIteration: + """Tests for logging iterations.""" + + def test_log_iteration_increments_count(self): + """Each log call should increment iteration count.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt="Test prompt", + response="Test response", + code_blocks=[], + ) + + assert logger.iteration_count == 0 + + logger.log(iteration) + assert logger.iteration_count == 1 + + logger.log(iteration) + assert logger.iteration_count == 2 + + logger.log(iteration) + assert logger.iteration_count == 3 + + def test_log_iteration_writes_entry(self): + """log should write iteration data to file.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt="Calculate 1+1", + response="Let me compute that.", + code_blocks=[], + final_answer="2", 
+ ) + + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["type"] == "iteration" + assert entry["iteration"] == 1 + assert entry["prompt"] == "Calculate 1+1" + assert entry["response"] == "Let me compute that." + assert entry["final_answer"] == "2" + assert "timestamp" in entry + + def test_log_iteration_with_code_blocks(self): + """Iteration with code blocks should serialize correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + code_result = REPLResult(stdout="42", stderr="", locals={"x": 42}) + iteration = RLMIteration( + prompt="Compute something", + response="Running code...", + code_blocks=[ + CodeBlock(code="x = 6 * 7\nprint(x)", result=code_result), + ], + ) + + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert len(entry["code_blocks"]) == 1 + code_block = entry["code_blocks"][0] + assert code_block["code"] == "x = 6 * 7\nprint(x)" + assert code_block["result"]["stdout"] == "42" + assert code_block["result"]["stderr"] == "" + + def test_log_multiple_iterations(self): + """Multiple iterations should be logged with correct iteration numbers.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + + for i in range(3): + iteration = RLMIteration( + prompt=f"Prompt {i + 1}", + response=f"Response {i + 1}", + code_blocks=[], + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 3 + + for i, line in enumerate(lines): + entry = json.loads(line) + assert entry["iteration"] == i + 1 + assert entry["prompt"] == f"Prompt {i + 1}" + + +class TestRLMLoggerJSONLFormat: + """Tests for JSONL format validation.""" + + def test_each_line_is_valid_json(self): + """Each line in the log file should be valid JSON.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="test-model", + max_depth=1, + max_iterations=5, + backend="openai", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + logger.log_metadata(metadata) + + for i in range(5): + iteration = RLMIteration( + prompt=f"Prompt {i}", + response=f"Response {i}", + code_blocks=[], + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + for line_num, line in enumerate(f, 1): + try: + json.loads(line) + except json.JSONDecodeError as e: + raise AssertionError(f"Line {line_num} is not valid JSON: {e}") from e + + def test_entries_end_with_newline(self): + """Each entry should end with a newline character.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt="Test", + response="Response", + code_blocks=[], + ) + logger.log(iteration) + logger.log(iteration) + + with open(logger.log_file_path, "rb") as f: + content = f.read() + + # Count newlines - should match number of entries + newline_count = content.count(b"\n") + assert newline_count == 2 + + def test_metadata_and_iterations_in_order(self): + """Metadata should come first, followed by iterations.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="gpt-4", + max_depth=2, + max_iterations=3, + backend="openai", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + logger.log_metadata(metadata) + + for _ in range(3): + iteration = RLMIteration( + prompt="Test", + 
response="Response", + code_blocks=[], + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 4 + + first_entry = json.loads(lines[0]) + assert first_entry["type"] == "metadata" + + for i, line in enumerate(lines[1:], 1): + entry = json.loads(line) + assert entry["type"] == "iteration" + assert entry["iteration"] == i + + +class TestRLMLoggerComplexData: + """Tests for logging complex data structures.""" + + def test_iteration_with_dict_prompt(self): + """Iteration with dict prompt should serialize correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt={"role": "user", "content": "Hello"}, + response="Hi there!", + code_blocks=[], + ) + + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["prompt"] == {"role": "user", "content": "Hello"} + + def test_iteration_with_multiple_code_blocks(self): + """Iteration with multiple code blocks should log all blocks.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt="Multi-step computation", + response="Running multiple blocks...", + code_blocks=[ + CodeBlock( + code="x = 10", + result=REPLResult(stdout="", stderr="", locals={"x": 10}), + ), + CodeBlock( + code="y = x * 2", + result=REPLResult(stdout="", stderr="", locals={"x": 10, "y": 20}), + ), + CodeBlock( + code="print(y)", + result=REPLResult(stdout="20", stderr="", locals={"x": 10, "y": 20}), + ), + ], + ) + + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert len(entry["code_blocks"]) == 3 + assert entry["code_blocks"][0]["code"] == "x = 10" + assert entry["code_blocks"][1]["code"] == "y = x * 2" + assert entry["code_blocks"][2]["code"] == "print(y)" + assert entry["code_blocks"][2]["result"]["stdout"] == "20" + + def test_code_block_with_stderr(self): + """Code block with stderr should be logged correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + iteration = RLMIteration( + prompt="Error test", + response="This will fail", + code_blocks=[ + CodeBlock( + code="1 / 0", + result=REPLResult( + stdout="", + stderr="ZeroDivisionError: division by zero", + locals={}, + ), + ), + ], + ) + + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["code_blocks"][0]["result"]["stderr"] == ( + "ZeroDivisionError: division by zero" + ) + + def test_metadata_with_complex_kwargs(self): + """Metadata with complex backend/environment kwargs should serialize.""" + with tempfile.TemporaryDirectory() as temp_dir: + logger = RLMLogger(temp_dir) + metadata = RLMMetadata( + root_model="gpt-4-turbo", + max_depth=5, + max_iterations=20, + backend="azure_openai", + backend_kwargs={ + "api_key": "sk-xxx", + "api_version": "2024-01-01", + "azure_endpoint": "https://example.openai.azure.com", + "timeout": 30, + "max_retries": 3, + }, + environment_type="docker", + environment_kwargs={ + "image": "python:3.11-slim", + "memory_limit": "2g", + "cpu_limit": 2, + "volumes": ["/data:/data:ro"], + }, + ) + + logger.log_metadata(metadata) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert entry["backend_kwargs"]["api_version"] == "2024-01-01" + assert entry["environment_kwargs"]["memory_limit"] == "2g" + assert entry["environment_kwargs"]["volumes"] == 
["/data:/data:ro"]