Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions graphiti_core/cross_encoder/openai_reranker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from openai import AsyncAzureOpenAI, AsyncOpenAI

from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, OpenAIClient, RateLimitError
from ..llm_client import LLMConfig, RateLimitError
from ..llm_client.openai_base_client import BaseOpenAIClient
from ..prompts import Message
from .client import CrossEncoderClient

Expand All @@ -35,7 +36,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
def __init__(
self,
config: LLMConfig | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None = None,
client: AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None = None,
):
"""
Initialize the OpenAIRerankerClient with the provided configuration and client.
Expand All @@ -45,15 +46,15 @@ def __init__(

Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (AsyncOpenAI | AsyncAzureOpenAI | OpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
client (AsyncOpenAI | AsyncAzureOpenAI | BaseOpenAIClient | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
"""
if config is None:
config = LLMConfig()

self.config = config
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
elif isinstance(client, OpenAIClient):
elif isinstance(client, BaseOpenAIClient):
self.client = client.client
else:
self.client = client
Expand Down
4 changes: 4 additions & 0 deletions graphiti_core/llm_client/openai_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any, ClassVar

import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

Expand Down Expand Up @@ -48,6 +49,9 @@ class BaseOpenAIClient(LLMClient):
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2

# Instance attribute (initialized in subclasses)
client: AsyncOpenAI | AsyncAzureOpenAI

def __init__(
self,
config: LLMConfig | None = None,
Expand Down
145 changes: 145 additions & 0 deletions tests/cross_encoder/test_openai_reranker_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""
Test file for OpenAIRerankerClient, specifically testing compatibility with
both OpenAIClient and AzureOpenAILLMClient instances.

This test validates the fix for issue #1006 where OpenAIRerankerClient
failed to properly support AzureOpenAILLMClient.
"""

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.llm_client import LLMConfig
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.openai_client import OpenAIClient


class MockAsyncOpenAI:
"""Mock AsyncOpenAI client for testing"""

def __init__(self, api_key=None, base_url=None):
self.api_key = api_key
self.base_url = base_url
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()


class MockAsyncAzureOpenAI:
"""Mock AsyncAzureOpenAI client for testing"""

def __init__(self):
self.chat = MagicMock()
self.chat.completions = MagicMock()
self.chat.completions.create = AsyncMock()


@pytest.fixture
def mock_openai_client():
"""Fixture to create a mocked OpenAIClient"""
client = OpenAIClient(config=LLMConfig(api_key='test-key'))
# Replace the internal client with our mock
client.client = MockAsyncOpenAI()
return client


@pytest.fixture
def mock_azure_openai_client():
"""Fixture to create a mocked AzureOpenAILLMClient"""
mock_azure = MockAsyncAzureOpenAI()
client = AzureOpenAILLMClient(
azure_client=mock_azure,
config=LLMConfig(api_key='test-key')
)
return client


def test_openai_reranker_accepts_openai_client(mock_openai_client):
"""Test that OpenAIRerankerClient properly unwraps OpenAIClient"""
# Create reranker with OpenAIClient
reranker = OpenAIRerankerClient(client=mock_openai_client)

# Verify the internal client is the unwrapped AsyncOpenAI instance
assert reranker.client == mock_openai_client.client
assert hasattr(reranker.client, 'chat')


def test_openai_reranker_accepts_azure_client(mock_azure_openai_client):
"""Test that OpenAIRerankerClient properly unwraps AzureOpenAILLMClient

This test validates the fix for issue #1006.
"""
# Create reranker with AzureOpenAILLMClient - this would fail before the fix
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

# Verify the internal client is the unwrapped AsyncAzureOpenAI instance
assert reranker.client == mock_azure_openai_client.client
assert hasattr(reranker.client, 'chat')


def test_openai_reranker_accepts_async_openai_directly():
"""Test that OpenAIRerankerClient accepts AsyncOpenAI directly"""
# Create a mock AsyncOpenAI
mock_async = MockAsyncOpenAI(api_key='test-key')

# Create reranker with AsyncOpenAI directly
reranker = OpenAIRerankerClient(client=mock_async)

# Verify the internal client is used as-is
assert reranker.client == mock_async
assert hasattr(reranker.client, 'chat')


def test_openai_reranker_creates_default_client():
"""Test that OpenAIRerankerClient creates a default client when none provided"""
config = LLMConfig(api_key='test-key')

# Create reranker without client
reranker = OpenAIRerankerClient(config=config)

# Verify a client was created
assert reranker.client is not None
# The default should be an AsyncOpenAI instance
from openai import AsyncOpenAI
assert isinstance(reranker.client, AsyncOpenAI)


@pytest.mark.asyncio
async def test_rank_method_with_azure_client(mock_azure_openai_client):
"""Test that rank method works correctly with AzureOpenAILLMClient"""
# Setup mock response for the chat completions
mock_response = SimpleNamespace(
choices=[
SimpleNamespace(
logprobs=SimpleNamespace(
content=[
SimpleNamespace(
top_logprobs=[
SimpleNamespace(token='True', logprob=-0.5)
]
)
]
)
)
]
)

mock_azure_openai_client.client.chat.completions.create.return_value = mock_response

# Create reranker with AzureOpenAILLMClient
reranker = OpenAIRerankerClient(client=mock_azure_openai_client)

# Test ranking
query = "test query"
passages = ["passage 1"]

# This would previously fail with AttributeError before the fix
results = await reranker.rank(query, passages)

# Verify the method was called
assert mock_azure_openai_client.client.chat.completions.create.called
assert len(results) == 1
assert results[0][0] == "passage 1"
Loading