diff --git a/graphiti_core/cross_encoder/openai_reranker_client.py b/graphiti_core/cross_encoder/openai_reranker_client.py
index 2e6c5b2ff..16f60065e 100644
--- a/graphiti_core/cross_encoder/openai_reranker_client.py
+++ b/graphiti_core/cross_encoder/openai_reranker_client.py
@@ -22,7 +22,8 @@
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 
 from ..helpers import semaphore_gather
-from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
+from ..llm_client import LLMConfig, RateLimitError
+from ..llm_client.openai_base_client import BaseOpenAIClient
 from ..prompts import Message
 from .client import CrossEncoderClient
 
@@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
     def __init__(
         self,
         config: LLMConfig | None = None,
-        client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
+        client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
     ):
         """
         Initialize the OpenAIRerankerClient with the provided configuration and client.
@@ -45,7 +46,7 @@ def __init__(
 
         Args:
             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
-            client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
+            client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
         """
         if config is None:
             config = LLMConfig()
@@ -53,7 +54,7 @@ def __init__(
         self.config = config
         if client is None:
             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
-        elif isinstance(client, OpenAIClient):
+        elif isinstance(client, BaseOpenAIClient):
             self.client = client.client
         else:
             self.client = client
diff --git a/graphiti_core/llm_client/openai_base_client.py b/graphiti_core/llm_client/openai_base_client.py
index 93e9c598e..259272704 100644
--- a/graphiti_core/llm_client/openai_base_client.py
+++ b/graphiti_core/llm_client/openai_base_client.py
@@ -21,6 +21,7 @@
 from typing import Any, ClassVar
 
 import openai
+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel
 
@@ -48,6 +49,9 @@ class BaseOpenAIClient(LLMClient):
     # Class-level constants
     MAX_RETRIES: ClassVar[int] = 2
 
+    # Instance attribute (initialized in subclasses)
+    client: AsyncOpenAI | AsyncAzureOpenAI
+
     def __init__(
         self,
         config: LLMConfig | None = None,
diff --git a/tests/cross_encoder/test_openai_reranker_client.py b/tests/cross_encoder/test_openai_reranker_client.py
new file mode 100644
index 000000000..8dc3ae540
--- /dev/null
+++ b/tests/cross_encoder/test_openai_reranker_client.py
@@ -0,0 +1,145 @@
+"""
+Test file for OpenAIRerankerClient, specifically testing compatibility with
+both OpenAIClient and AzureOpenAILLMClient instances.
+
+This test validates the fix for issue #1006 where OpenAIRerankerClient
+failed to properly support AzureOpenAILLMClient.
+"""
+
+from types import SimpleNamespace
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
+from graphiti_core.llm_client import LLMConfig
+from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
+from graphiti_core.llm_client.openai_client import OpenAIClient
+
+
+class MockAsyncOpenAI:
+    """Mock AsyncOpenAI client for testing"""
+
+    def __init__(self, api_key=None, base_url=None):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.chat = MagicMock()
+        self.chat.completions = MagicMock()
+        self.chat.completions.create = AsyncMock()
+
+
+class MockAsyncAzureOpenAI:
+    """Mock AsyncAzureOpenAI client for testing"""
+
+    def __init__(self):
+        self.chat = MagicMock()
+        self.chat.completions = MagicMock()
+        self.chat.completions.create = AsyncMock()
+
+
+@pytest.fixture
+def mock_openai_client():
+    """Fixture to create a mocked OpenAIClient"""
+    client = OpenAIClient(config=LLMConfig(api_key='test-key'))
+    # Replace the internal client with our mock
+    client.client = MockAsyncOpenAI()
+    return client
+
+
+@pytest.fixture
+def mock_azure_openai_client():
+    """Fixture to create a mocked AzureOpenAILLMClient"""
+    mock_azure = MockAsyncAzureOpenAI()
+    client = AzureOpenAILLMClient(
+        azure_client=mock_azure,
+        config=LLMConfig(api_key='test-key')
+    )
+    return client
+
+
+def test_openai_reranker_accepts_openai_client(mock_openai_client):
+    """Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
+    # Create reranker with OpenAIClient
+    reranker = OpenAIRerankerClient(client=mock_openai_client)
+
+    # Verify the internal client is the unwrapped AsyncOpenAI instance
+    assert reranker.client == mock_openai_client.client
+    assert hasattr(reranker.client, 'chat')
+
+
+def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
+    """Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient
+
+    This test validates the fix for issue #1006.
+    """
+    # Create reranker with AzureOpenAILLMClient - this would fail before the fix
+    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
+
+    # Verify the internal client is the unwrapped AsyncAzureOpenAI instance
+    assert reranker.client == mock_azure_openai_client.client
+    assert hasattr(reranker.client, 'chat')
+
+
+def test_openai_reranker_accepts_async_openai_directly():
+    """Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
+    # Create a mock AsyncOpenAI
+    mock_async = MockAsyncOpenAI(api_key='test-key')
+
+    # Create reranker with AsyncOpenAI directly
+    reranker = OpenAIRerankerClient(client=mock_async)
+
+    # Verify the internal client is used as-is
+    assert reranker.client == mock_async
+    assert hasattr(reranker.client, 'chat')
+
+
+def test_openai_reranker_creates_default_client():
+    """Test that OpenAIRerankerClient creates a default client when none provided"""
+    config = LLMConfig(api_key='test-key')
+
+    # Create reranker without client
+    reranker = OpenAIRerankerClient(config=config)
+
+    # Verify a client was created
+    assert reranker.client is not None
+    # The default should be an AsyncOpenAI instance
+    from openai import AsyncOpenAI
+    assert isinstance(reranker.client, AsyncOpenAI)
+
+
+@pytest.mark.asyncio
+async def test_rank_method_with_azure_client(mock_azure_openai_client):
+    """Test that rank method works correctly with AzureOpenAILLMClient"""
+    # Setup mock response for the chat completions
+    mock_response = SimpleNamespace(
+        choices=[
+            SimpleNamespace(
+                logprobs=SimpleNamespace(
+                    content=[
+                        SimpleNamespace(
+                            top_logprobs=[
+                                SimpleNamespace(token='True', logprob=-0.5)
+                            ]
+                        )
+                    ]
+                )
+            )
+        ]
+    )
+
+    mock_azure_openai_client.client.chat.completions.create.return_value = mock_response
+
+    # Create reranker with AzureOpenAILLMClient
+    reranker = OpenAIRerankerClient(client=mock_azure_openai_client)
+
+    # Test ranking
+    query = "test query"
+    passages = ["passage 1"]
+
+    # This would previously fail with AttributeError before the fix
+    results = await reranker.rank(query, passages)
+
+    # Verify the method was called
+    assert mock_azure_openai_client.client.chat.completions.create.called
+    assert len(results) == 1
+    assert results[0][0] == "passage 1"
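With this change, OpenAIRerankerClient accepts any BaseOpenAIClient subclass (OpenAIClient or AzureOpenAILLMClient) and unwraps it to the underlying AsyncOpenAI/AsyncAzureOpenAI instance. The sketch below illustrates the intended wiring after the fix; the Azure endpoint, API version, and key are placeholder values and are not taken from this patch.

import asyncio

from openai import AsyncAzureOpenAI

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient


async def main() -> None:
    # Placeholder Azure endpoint/credentials -- substitute real values.
    azure_client = AsyncAzureOpenAI(
        api_key='<azure-api-key>',
        api_version='2024-06-01',
        azure_endpoint='https://<your-resource>.openai.azure.com',
    )
    llm_client = AzureOpenAILLMClient(azure_client=azure_client, config=LLMConfig())

    # Previously the wrapper was stored as-is and rank() failed with an
    # AttributeError (issue #1006); the reranker now unwraps it to the
    # underlying AsyncAzureOpenAI client.
    reranker = OpenAIRerankerClient(client=llm_client)

    ranked = await reranker.rank('example query', ['passage 1', 'passage 2'])
    print(ranked)  # list of (passage, score) tuples


asyncio.run(main())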