4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250917154150573751.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add retry on llm calls"
}
29 changes: 29 additions & 0 deletions benchmark_qed/config/llm_config.py
@@ -54,6 +54,31 @@ class AuthType(StrEnum):
AzureManagedIdentity = "azure_managed_identity"


class RetryConfig(BaseModel):
"""Configuration for retrying failed requests."""

retries: int = Field(
default=3,
description="The maximum number of retry attempts.",
)
base_delay: float = Field(
default=1.0,
description="Initial delay between retries in seconds.",
)
max_delay: float = Field(
default=60.0,
description="Maximum delay between retries in seconds.",
)
backoff_factor: float = Field(
default=2.0,
description="Multiplier for exponential backoff.",
)
jitter: bool = Field(
default=True,
description="Whether to add random jitter to delay times.",
)


class LLMConfig(BaseModel):
"""Configuration for the LLM to use."""

@@ -73,6 +98,10 @@ class LLMConfig(BaseModel):
default=4,
description="The number of concurrent requests to send to the model. This should be a positive integer.",
)
retry_config: RetryConfig = Field(
default_factory=RetryConfig,
description="Configuration for retry behavior.",
)
llm_provider: LLMProvider | str = Field(
default=LLMProvider.OpenAIChat,
description="The type of model to use.",
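A minimal sketch of how the new retry settings might be supplied, using only the fields shown in this diff; the values below are hypothetical, and the object is then passed to LLMConfig via its retry_config field alongside the existing fields such as model and concurrent_requests:

from benchmark_qed.config.llm_config import RetryConfig

# Hypothetical values for illustration; the defaults are retries=3,
# base_delay=1.0, max_delay=60.0, backoff_factor=2.0, jitter=True.
retry = RetryConfig(
    retries=5,           # total attempts handed to stop_after_attempt
    base_delay=0.5,      # first wait of roughly 0.5s
    max_delay=30.0,      # cap any single wait at 30s
    backoff_factor=2.0,  # double the wait after each failed attempt
    jitter=True,         # add up to base_delay * 0.25 seconds of random jitter
)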
109 changes: 93 additions & 16 deletions benchmark_qed/llm/provider/azure.py
@@ -7,11 +7,44 @@
from azure.ai.inference.aio import ChatCompletionsClient, EmbeddingsClient
from azure.ai.inference.models import ChatCompletions, EmbeddingEncodingFormat
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from azure.identity import DefaultAzureCredential

from benchmark_qed.config.llm_config import AuthType, LLMConfig
from tenacity import (
AsyncRetrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)

from benchmark_qed.config.llm_config import AuthType, LLMConfig, RetryConfig
from benchmark_qed.llm.type.base import BaseModelOutput, BaseModelResponse, Usage

# Common retryable exceptions for Azure services
RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = (HttpResponseError,)


def _async_retry(retry_config: RetryConfig) -> AsyncRetrying:
"""Create a tenacity AsyncRetrying instance from LLMConfig.

Args:
llm_config: The LLM configuration containing retry settings.

Returns
-------
AsyncRetrying instance configured with the provided settings.
"""
return AsyncRetrying(
stop=stop_after_attempt(retry_config.retries),
wait=wait_exponential_jitter(
initial=retry_config.base_delay,
max=retry_config.max_delay,
jitter=retry_config.base_delay * 0.25 if retry_config.jitter else 0,
exp_base=retry_config.backoff_factor,
),
retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
reraise=True,
)


class AzureInferenceChat:
"""An Azure Chat Model provider."""
@@ -29,11 +62,34 @@ def __init__(self, llm_config: LLMConfig) -> None:
self._model = llm_config.model
self._semaphore = asyncio.Semaphore(llm_config.concurrent_requests)
self._usage = Usage(model=llm_config.model)
self._retry_config = llm_config.retry_config

def get_usage(self) -> dict[str, Any]:
"""Get the usage of the Model."""
return self._usage.model_dump()

async def _complete_chat(
self, messages: list[dict[str, str]], **kwargs: dict[str, Any]
) -> ChatCompletions:
"""Complete a chat request using the Azure client.

Args:
messages: The messages to send to the model.
kwargs: Additional arguments to pass to the model.

Returns
-------
The chat completion response.
"""
return cast(
ChatCompletions,
await self._client.complete(
model=self._model,
messages=messages,
**kwargs, # type: ignore
), # type: ignore
)

async def chat(
self, messages: list[dict[str, str]], **kwargs: dict[str, Any]
) -> BaseModelResponse:
@@ -48,15 +104,15 @@ async def chat(
-------
The response from the Model.
"""
response = None
async with self._semaphore:
response: ChatCompletions = cast(
ChatCompletions,
await self._client.complete(
model=self._model,
messages=messages,
**kwargs, # type: ignore
), # type: ignore
)
async for attempt in _async_retry(self._retry_config):
with attempt:
response = await self._complete_chat(messages, **kwargs)

if response is None:
msg = "No response received from Azure Chat API"
raise ValueError(msg)

content = response.choices[0].message.content.replace(
"<|im_start|>assistant<|im_sep|>", ""
@@ -103,11 +159,30 @@ def __init__(self, llm_config: LLMConfig) -> None:
self._model = llm_config.model
self._semaphore = asyncio.Semaphore(llm_config.concurrent_requests)
self._usage = Usage(model=llm_config.model)
self._retry_config = llm_config.retry_config

def get_usage(self) -> dict[str, Any]:
"""Get the usage of the Model."""
return self._usage.model_dump()

async def _embed_text(self, text_list: list[str], **kwargs: Any) -> Any:
"""Generate embeddings using the Azure client.

Args:
text_list: The list of text to generate embeddings for.
kwargs: Additional arguments to pass to the model.

Returns
-------
The embedding response.
"""
return await self._client.embed(
model=self._model,
input=text_list,
encoding_format=EmbeddingEncodingFormat.FLOAT,
**kwargs,
)

async def embed(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.
@@ -120,13 +195,15 @@ async def embed(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
response = None
async with self._semaphore:
response = await self._client.embed(
model=self._model,
input=text_list,
encoding_format=EmbeddingEncodingFormat.FLOAT,
**kwargs,
)
async for attempt in _async_retry(self._retry_config):
with attempt:
response = await self._embed_text(text_list, **kwargs)

if response is None:
msg = "No response received from Azure Embedding API"
raise ValueError(msg)

self._usage.add_usage(prompt_tokens=response.usage.prompt_tokens)

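The retry loop used in both providers follows the standard tenacity AsyncRetrying pattern: iterate over attempts, run the call inside the attempt context, and let reraise=True surface the last error once attempts are exhausted. Below is a minimal, self-contained sketch of that pattern with the same shape of settings; TransientError and flaky_call are hypothetical stand-ins for HttpResponseError and the Azure client call:

import asyncio

from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential_jitter,
)


class TransientError(Exception):
    """Hypothetical stand-in for a retryable service error."""


async def flaky_call(state: dict) -> str:
    """Fail twice, then succeed, to exercise the retry loop."""
    state["calls"] += 1
    if state["calls"] < 3:
        raise TransientError("transient failure")
    return "ok"


async def main() -> None:
    state = {"calls": 0}
    result = None
    # Same shape as _async_retry(): 3 attempts, exponential backoff with jitter.
    async for attempt in AsyncRetrying(
        stop=stop_after_attempt(3),
        wait=wait_exponential_jitter(initial=0.1, max=1.0, exp_base=2.0, jitter=0.025),
        retry=retry_if_exception_type((TransientError,)),
        reraise=True,
    ):
        with attempt:
            result = await flaky_call(state)
    print(result, state["calls"])  # prints: ok 3


asyncio.run(main())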
112 changes: 99 additions & 13 deletions benchmark_qed/llm/provider/openai.py
@@ -5,13 +5,55 @@
from typing import Any

from azure.identity import DefaultAzureCredential, get_bearer_token_provider
from openai import AsyncAzureOpenAI, AsyncOpenAI

from benchmark_qed.config.llm_config import AuthType, LLMConfig
from openai import (
APITimeoutError,
AsyncAzureOpenAI,
AsyncOpenAI,
InternalServerError,
RateLimitError,
)
from tenacity import (
AsyncRetrying,
retry_if_exception_type,
stop_after_attempt,
wait_exponential_jitter,
)

from benchmark_qed.config.llm_config import AuthType, LLMConfig, RetryConfig
from benchmark_qed.llm.type.base import BaseModelOutput, BaseModelResponse, Usage

REASONING_MODELS = ["o3", "o4-mini", "o3-mini", "o1-mini", "o1", "o1-pro"]

# Common retryable exceptions for OpenAI services
RETRYABLE_EXCEPTIONS: tuple[type[Exception], ...] = (
APITimeoutError,
InternalServerError,
RateLimitError,
)


def async_retry(retry_config: RetryConfig) -> AsyncRetrying:
"""Create a tenacity AsyncRetrying instance from LLMConfig.

Args:
llm_config: The LLM configuration containing retry settings.

Returns
-------
AsyncRetrying instance configured with the provided settings.
"""
return AsyncRetrying(
stop=stop_after_attempt(retry_config.retries),
wait=wait_exponential_jitter(
initial=retry_config.base_delay,
max=retry_config.max_delay,
jitter=retry_config.base_delay * 0.25 if retry_config.jitter else 0,
exp_base=retry_config.backoff_factor,
),
retry=retry_if_exception_type(RETRYABLE_EXCEPTIONS),
reraise=True,
)


class BaseOpenAIChat:
"""An OpenAI Chat Model provider."""
@@ -23,11 +65,31 @@ def __init__(
self._model = llm_config.model
self._semaphore = asyncio.Semaphore(llm_config.concurrent_requests)
self._usage = Usage(model=llm_config.model)
self._retry_config = llm_config.retry_config

def get_usage(self) -> dict[str, Any]:
"""Get the usage of the Model."""
return self._usage.model_dump()

async def _create_chat_completion(
self, messages: list[dict[str, str]], **kwargs: dict[str, Any]
) -> Any:
"""Create a chat completion using the OpenAI client.

Args:
messages: The messages to send to the model.
kwargs: Additional arguments to pass to the model.

Returns
-------
The chat completion response.
"""
return await self._client.chat.completions.create(
model=self._model,
messages=messages, # type: ignore
**kwargs, # type: ignore
)

async def chat(
self, messages: list[dict[str, str]], **kwargs: dict[str, Any]
) -> BaseModelResponse:
@@ -45,12 +107,15 @@ async def chat(
if self._model in REASONING_MODELS and "temperature" in kwargs:
kwargs.pop("temperature")

response = None
async with self._semaphore:
response = await self._client.chat.completions.create(
model=self._model,
messages=messages, # type: ignore
**kwargs, # type: ignore
)
async for attempt in async_retry(self._retry_config):
with attempt:
response = await self._create_chat_completion(messages, **kwargs)

if response is None:
msg = "No response received from Azure Chat API"
raise ValueError(msg)

history = [
*messages,
@@ -134,11 +199,29 @@ def __init__(
self._model = llm_config.model
self._semaphore = asyncio.Semaphore(llm_config.concurrent_requests)
self._usage = Usage(model=llm_config.model)
self._retry_config = llm_config.retry_config

def get_usage(self) -> dict[str, Any]:
"""Get the usage of the Model."""
return self._usage.model_dump()

async def _create_embeddings(self, text_list: list[str], **kwargs: Any) -> Any:
"""Create embeddings using the OpenAI client.

Args:
text_list: The list of text to generate embeddings for.
kwargs: Additional arguments to pass to the model.

Returns
-------
The embedding response.
"""
return await self._client.embeddings.create(
model=self._model,
input=text_list,
**kwargs,
)

async def embed(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
"""
Generate an embedding vector for the given list of strings.
@@ -151,12 +234,15 @@ async def embed(self, text_list: list[str], **kwargs: Any) -> list[list[float]]:
-------
A collections of list of floats representing the embedding vector for each item in the batch.
"""
response = None
async with self._semaphore:
response = await self._client.embeddings.create(
model=self._model,
input=text_list,
**kwargs,
)
async for attempt in async_retry(self._retry_config):
with attempt:
response = await self._create_embeddings(text_list, **kwargs)

if response is None:
msg = "No response received from Azure Chat API"
raise ValueError(msg)

self._usage.add_usage(prompt_tokens=response.usage.prompt_tokens)

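Note that stop_after_attempt counts total attempts, so with this wiring the default retries=3 allows at most two retries after the initial call. A quick sketch of the approximate waits produced by the default settings (initial=1.0, exp_base=2.0, max=60.0), before the uniform jitter of up to base_delay * 0.25 seconds is added:

# Approximate base waits between attempts under the RetryConfig defaults.
initial, exp_base, max_delay = 1.0, 2.0, 60.0
for failed_attempt in range(1, 6):
    wait = min(initial * exp_base ** (failed_attempt - 1), max_delay)
    print(f"after failed attempt {failed_attempt}: wait ~ {wait:.1f}s (+ jitter)")
# after failed attempt 1: ~1.0s; attempt 2: ~2.0s; attempt 3: ~4.0s; ...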
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
"statsmodels>=0.14.4",
"tiktoken>=0.9.0",
"typer>=0.15.1",
"tenacity>=9.1.2",
]

[dependency-groups]