Add ModelConfig for better chat model routing.
zh-plus committed Dec 19, 2024
1 parent 51dd6e4 commit 6efe333
Showing 10 changed files with 127 additions and 62 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -47,6 +47,25 @@ e.g. [OpenAI-GPT](https://github.com/openai/openai-python), [Anthropic-Claude](h
```shell
pip install "faster-whisper @ https://github.com/SYSTRAN/faster-whisper/archive/8327d8cc647266ed66f6cd878cf97eccface7351.tar.gz"
```
- 2024.12.19: Add `ModelConfig` for chat model routing, which is more flexible than a plain model-name string. A `ModelConfig` takes the form `ModelConfig(provider='<provider>', name='<model-name>', base_url='<url>', api_key='<api-key>', proxy='<proxy>')`, e.g.:
```python

from openlrc import LRCer, ModelConfig, ModelProvider

# DeepSeek exposes an OpenAI-compatible API, so it is routed through the OPENAI provider.
chatbot_model1 = ModelConfig(
    provider=ModelProvider.OPENAI,
    name='deepseek-chat',
    base_url='https://api.deepseek.com/beta',
    api_key='sk-APIKEY'
)
chatbot_model2 = ModelConfig(
    provider=ModelProvider.OPENAI,
    name='gpt-4o-mini',
    api_key='sk-APIKEY'
)
lrcer = LRCer(chatbot_model=chatbot_model1, retry_model=chatbot_model2)
```
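The plain string form is still supported, with an optional '<provider>:' prefix; a minimal sketch (assuming `OPENAI_API_KEY` is set in the environment):
```python
from openlrc import LRCer

# String routing: '<model-name>' or '<provider>:<model-name>'.
lrcer = LRCer(chatbot_model='openai:gpt-4o-mini')
```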

## Installation ⚙️

4 changes: 2 additions & 2 deletions openlrc/__init__.py
@@ -1,9 +1,9 @@
# Copyright (C) 2024. Hao Zheng
# All rights reserved.

from openlrc.models import list_chatbot_models
from openlrc.models import list_chatbot_models, ModelConfig, ModelProvider
from openlrc.openlrc import LRCer

__all__ = ('LRCer',)
__all__ = ('LRCer', 'ModelConfig', 'list_chatbot_models', 'ModelProvider')
__version__ = '1.5.2'
__author__ = 'zh-plus'
53 changes: 37 additions & 16 deletions openlrc/agents.py
@@ -7,9 +7,10 @@

from json_repair import repair_json

from openlrc.chatbot import route_chatbot, GPTBot, ClaudeBot, GeminiBot
from openlrc.chatbot import route_chatbot, GPTBot, ClaudeBot, GeminiBot, provider2chatbot
from openlrc.context import TranslationContext, TranslateInfo
from openlrc.logger import logger
from openlrc.models import ModelConfig, ModelProvider
from openlrc.prompter import ChunkedTranslatePrompter, ContextReviewPrompter, ProofreaderPrompter, PROOFREAD_PREFIX, \
ContextReviewerValidatePrompter, TranslationEvaluatorPrompter
from openlrc.validators import POTENTIAL_PREFIX_COMBOS
@@ -24,23 +25,42 @@ class Agent(abc.ABC):
"""
TEMPERATURE = 1

def _initialize_chatbot(self, chatbot_model: str, fee_limit: float, proxy: str, base_url_config: Optional[dict]):
def _initialize_chatbot(self, chatbot_model: Union[str, ModelConfig], fee_limit: float, proxy: str,
base_url_config: Optional[dict]):
"""
Initialize a chatbot instance based on the provided parameters.
Args:
chatbot_model (str): The name of the chatbot model to use.
chatbot_model (Union[str, ModelConfig]): The name of the chatbot model or ModelConfig.
fee_limit (float): The maximum fee allowed for API calls.
proxy (str): Proxy server to use for API calls.
base_url_config (Optional[dict]): Configuration for the base URL of the API.
Returns:
Union[ClaudeBot, GPTBot, GeminiBot]: An instance of the appropriate chatbot class.
"""
chatbot_cls: Union[Type[ClaudeBot], Type[GPTBot], Type[GeminiBot]]
chatbot_cls, model_name = route_chatbot(chatbot_model)
return chatbot_cls(model_name=model_name, fee_limit=fee_limit, proxy=proxy, retry=2,
temperature=self.TEMPERATURE, base_url_config=base_url_config)

if isinstance(chatbot_model, str):
chatbot_cls: Union[Type[ClaudeBot], Type[GPTBot], Type[GeminiBot]]
chatbot_cls, model_name = route_chatbot(chatbot_model)
return chatbot_cls(model_name=model_name, fee_limit=fee_limit, proxy=proxy, retry=2,
temperature=self.TEMPERATURE, base_url_config=base_url_config)
elif isinstance(chatbot_model, ModelConfig):
chatbot_cls = provider2chatbot[chatbot_model.provider]
proxy = chatbot_model.proxy or proxy

if chatbot_model.base_url:
if chatbot_model.provider == ModelProvider.OPENAI:
base_url_config = {'openai': chatbot_model.base_url}
elif chatbot_model.provider == ModelProvider.ANTHROPIC:
base_url_config = {'anthropic': chatbot_model.base_url}
else:
base_url_config = None
logger.warning(f'Unsupported base_url configuration for provider: {chatbot_model.provider}')

return chatbot_cls(model_name=chatbot_model.name, fee_limit=fee_limit, proxy=proxy, retry=2,
temperature=self.TEMPERATURE, base_url_config=base_url_config,
api_key=chatbot_model.api_key)
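
For reference, both accepted shapes can be passed straight to an agent; a minimal sketch (hypothetical language codes, placeholder API key):
```python
from openlrc.agents import ChunkedTranslatorAgent
from openlrc.models import ModelConfig, ModelProvider

# String form: resolved through route_chatbot(); assumes OPENAI_API_KEY is set.
agent_a = ChunkedTranslatorAgent('en', 'zh', chatbot_model='gpt-4o-mini')

# ModelConfig form: resolved through provider2chatbot; the per-model base_url,
# api_key, and proxy override the globally configured values.
cfg = ModelConfig(provider=ModelProvider.OPENAI, name='deepseek-chat',
                  base_url='https://api.deepseek.com/beta', api_key='sk-APIKEY')
agent_b = ChunkedTranslatorAgent('en', 'zh', chatbot_model=cfg)
```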


class ChunkedTranslatorAgent(Agent):
@@ -56,7 +76,7 @@ class ChunkedTranslatorAgent(Agent):
TEMPERATURE = 1.0

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-4o-mini', fee_limit: float = 0.8, proxy: str = None,
chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini', fee_limit: float = 0.8, proxy: str = None,
base_url_config: Optional[dict] = None):
"""
Initialize the ChunkedTranslatorAgent.
Expand All @@ -65,7 +85,7 @@ def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
src_lang (str): The source language.
target_lang (str): The target language for translation.
info (TranslateInfo): Additional translation information.
chatbot_model (str): The name of the chatbot model to use.
chatbot_model (Union[str, ModelConfig]): The name of the chatbot model or ModelConfig.
fee_limit (float): The maximum fee allowed for API calls.
proxy (str): Proxy server to use for API calls.
base_url_config (Optional[dict]): Configuration for the base URL of the API.
@@ -193,7 +213,7 @@ class ContextReviewerAgent(Agent):
TEMPERATURE = 0.6

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-4o-mini', retry_model=None,
chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini', retry_model=None,
fee_limit: float = 0.8, proxy: str = None,
base_url_config: Optional[dict] = None):
"""
@@ -203,8 +223,8 @@ def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
src_lang (str): The source language.
target_lang (str): The target language.
info (TranslateInfo): Additional translation information.
chatbot_model (str): The name of the primary chatbot model to use.
retry_model (str): The name of the backup chatbot model to use for retries.
chatbot_model (Union[str, ModelConfig]): The name or ModelConfig of the primary chatbot model.
retry_model (Union[str, ModelConfig]): The name or ModelConfig of the backup chatbot model to use for retries.
fee_limit (float): The maximum fee allowed for API calls.
proxy (str): Proxy server to use for API calls.
base_url_config (Optional[dict]): Configuration for the base URL of the API.
@@ -331,7 +351,7 @@ class ProofreaderAgent(Agent):
TEMPERATURE = 0.8

def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
chatbot_model: str = 'gpt-4o-mini', fee_limit: float = 0.8, proxy: str = None,
chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini', fee_limit: float = 0.8, proxy: str = None,
base_url_config: Optional[dict] = None):
"""
Initialize the ProofreaderAgent.
@@ -340,7 +360,7 @@ def __init__(self, src_lang, target_lang, info: TranslateInfo = TranslateInfo(),
src_lang (str): The source language.
target_lang (str): The target language.
info (TranslateInfo): Additional translation information.
chatbot_model (str): The name of the chatbot model to use.
chatbot_model (Union[str, ModelConfig]): The name or ModelConfig of the chatbot model to use.
fee_limit (float): The maximum fee allowed for API calls.
proxy (str): Proxy server to use for API calls.
base_url_config (Optional[dict]): Configuration for the base URL of the API.
@@ -404,13 +424,14 @@ class TranslationEvaluatorAgent(Agent):

TEMPERATURE = 0.95

def __init__(self, chatbot_model: str = 'gpt-4o-mini', fee_limit: float = 0.8, proxy: str = None,
def __init__(self, chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini', fee_limit: float = 0.8,
proxy: str = None,
base_url_config: Optional[dict] = None):
"""
Initialize the TranslationEvaluatorAgent.
Args:
chatbot_model (str): The name of the chatbot model to use.
chatbot_model (Union[str, ModelConfig]): The name of the chatbot model or ModelConfig.
fee_limit (float): The maximum fee allowed for API calls.
proxy (str): Proxy server to use for API calls.
base_url_config (Optional[dict]): Configuration for the base URL of the API.
23 changes: 15 additions & 8 deletions openlrc/chatbot.py
@@ -27,6 +27,7 @@
from openlrc.models import Models, ModelInfo, ModelProvider
from openlrc.utils import get_messages_token_number, get_text_token_number

# The default mapping for model name to chatbot class.
model2chatbot = {}


@@ -51,7 +52,7 @@ def _register_chatbot(cls):
return cls


def route_chatbot(model) -> (type, str):
def route_chatbot(model: str) -> (type, str):
if ':' in model:
chatbot_type, chatbot_model = re.match(r'(.+):(.+)', model).groups()
chatbot_type, chatbot_model = chatbot_type.strip().lower(), chatbot_model.strip()
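
For illustration, a prefixed model string is split into a chatbot class and a bare model name; a sketch (assuming 'openai' resolves to GPTBot in the lookup that follows):
```python
from openlrc.chatbot import route_chatbot

# The prefix is lowercased and both parts are stripped of whitespace.
chatbot_cls, model_name = route_chatbot('OpenAI: gpt-4o-mini')
print(chatbot_cls.__name__, model_name)  # expected: GPTBot gpt-4o-mini
```
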
@@ -174,7 +175,7 @@ def __str__(self):
@_register_chatbot
class GPTBot(ChatBot):
def __init__(self, model_name='gpt-4o-mini', temperature=1, top_p=1, retry=8, max_async=16, json_mode=False,
fee_limit=0.05, proxy=None, base_url_config=None):
fee_limit=0.05, proxy=None, base_url_config=None, api_key=None):

# clamp temperature to 0-2
temperature = max(0, min(2, temperature))
@@ -186,7 +187,7 @@ def __init__(self, model_name='gpt-4o-mini', temperature=1, top_p=1, retry=8, ma
super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit, is_beta)

self.async_client = AsyncGPTClient(
api_key=os.environ['OPENAI_API_KEY'],
api_key=api_key or os.environ['OPENAI_API_KEY'],
http_client=httpx.AsyncClient(proxy=proxy),
base_url=base_url_config['openai'] if base_url_config and base_url_config['openai'] else None
)
@@ -257,16 +258,15 @@ def _get_sleep_time(error):
@_register_chatbot
class ClaudeBot(ChatBot):
def __init__(self, model_name='claude-3-5-sonnet-20241022', temperature=1, top_p=1, retry=8, max_async=16,
fee_limit=0.8,
proxy=None, base_url_config=None):
fee_limit=0.8, proxy=None, base_url_config=None, api_key=None):

# clamp temperature to 0-1
temperature = max(0, min(1, temperature))

super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit)

self.async_client = AsyncAnthropic(
api_key=os.environ['ANTHROPIC_API_KEY'],
api_key=api_key or os.environ['ANTHROPIC_API_KEY'],
http_client=httpx.AsyncClient(
proxy=proxy
),
@@ -342,14 +342,14 @@ def _get_sleep_time(self, error):
@_register_chatbot
class GeminiBot(ChatBot):
def __init__(self, model_name='gemini-1.5-flash', temperature=1, top_p=1, retry=8, max_async=16, fee_limit=0.8,
proxy=None, base_url_config=None):
proxy=None, base_url_config=None, api_key=None):
self.temperature = max(0, min(1, temperature))

super().__init__(model_name, temperature, top_p, retry, max_async, fee_limit)

self.model_name = model_name

genai.configure(api_key=os.environ['GOOGLE_API_KEY'])
genai.configure(api_key=api_key or os.environ['GOOGLE_API_KEY'])
self.config = GenerationConfig(temperature=self.temperature, top_p=self.top_p)
# Should not block any translation-related content.
self.safety_settings = {
@@ -428,3 +428,10 @@ async def _create_achat(self, messages: List[Dict], stop_sequences: Optional[Lis
raise ChatBotException('Failed to create a chat.')

return response


provider2chatbot = {
ModelProvider.OPENAI: GPTBot,
ModelProvider.ANTHROPIC: ClaudeBot,
ModelProvider.GOOGLE: GeminiBot
}
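
This mapping is what `_initialize_chatbot` uses to turn a `ModelConfig` into a bot instance; a minimal sketch (placeholder key):
```python
from openlrc.chatbot import provider2chatbot
from openlrc.models import ModelConfig, ModelProvider

cfg = ModelConfig(provider=ModelProvider.ANTHROPIC, name='claude-3-5-sonnet-20241022')
chatbot_cls = provider2chatbot[cfg.provider]  # ClaudeBot
bot = chatbot_cls(model_name=cfg.name, api_key='sk-ant-APIKEY')
```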
6 changes: 4 additions & 2 deletions openlrc/context.py
@@ -1,15 +1,17 @@
# Copyright (C) 2024. Hao Zheng
# All rights reserved.
import re
from typing import Optional
from typing import Optional, Union

from pydantic import BaseModel

from openlrc import ModelConfig


class TranslationContext(BaseModel):
summary: Optional[str] = ''
scene: Optional[str] = ''
model: Optional[str] = None
model: Optional[Union[str, ModelConfig]] = None
guideline: Optional[str] = None

def update(self, **args):
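
For illustration, the widened `model` field now carries either form; a minimal sketch (assuming pydantic's stdlib-dataclass support validates `ModelConfig` instances):
```python
from openlrc.context import TranslationContext
from openlrc.models import ModelConfig, ModelProvider

ctx = TranslationContext(summary='A short cooking video.',
                         model=ModelConfig(provider=ModelProvider.OPENAI, name='gpt-4o-mini'))
```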
18 changes: 3 additions & 15 deletions openlrc/evaluate.py
@@ -1,9 +1,10 @@
# Copyright (C) 2024. Hao Zheng
# All rights reserved.
import abc
from typing import Union

from openlrc.agents import TranslationEvaluatorAgent
from openlrc.logger import logger
from openlrc.models import ModelConfig


class TranslationEvaluator(abc.ABC):
@@ -25,21 +26,8 @@ class LLMTranslationEvaluator(TranslationEvaluator):
Evaluate the translated texts using large language models.
"""

def __init__(self, chatbot_model: str = 'gpt-4o-mini'):
def __init__(self, chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini'):
self.agenet = TranslationEvaluatorAgent(chatbot_model=chatbot_model)
self.recommended_model = {
'gpt-4',
'claude-3-sonnet',
'claude-3-opus',
'gemini-1.5-pro'
}

for m in self.recommended_model:
if chatbot_model.startswith(m):
self.agenet = TranslationEvaluatorAgent(chatbot_model=chatbot_model)
break
else:
logger.warning(f'Chatbot model {chatbot_model} is not in the recommended list for evaluating translations.')

def evaluate(self, src_texts, target_texts, src_lang=None, target_lang=None):
return self.agenet.evaluate(src_texts, target_texts)
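
With the recommended-model check removed, any string or `ModelConfig` is accepted here; a minimal sketch (placeholder key, hypothetical inputs):
```python
from openlrc.evaluate import LLMTranslationEvaluator
from openlrc.models import ModelConfig, ModelProvider

evaluator = LLMTranslationEvaluator(
    chatbot_model=ModelConfig(provider=ModelProvider.OPENAI, name='gpt-4o-mini', api_key='sk-APIKEY')
)
# Hypothetical parallel texts to score.
result = evaluator.evaluate(['Hello there.'], ['Bonjour.'], src_lang='en', target_lang='fr')
```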
12 changes: 12 additions & 0 deletions openlrc/models.py
@@ -13,6 +13,18 @@ class ModelProvider(Enum):
THIRD_PARTY = "third_party"


@dataclass
class ModelConfig:
provider: ModelProvider
name: str
base_url: Optional[str] = None
api_key: Optional[str] = None
proxy: Optional[str] = None

def __str__(self):
return f'{self.provider.value}:{self.name}'
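
A quick sketch of the string form produced by `__str__`, which mirrors the '<provider>:<model-name>' syntax accepted elsewhere (assuming `ModelProvider.OPENAI.value == 'openai'`):
```python
from openlrc.models import ModelConfig, ModelProvider

cfg = ModelConfig(provider=ModelProvider.OPENAI, name='gpt-4o-mini')
print(cfg)  # openai:gpt-4o-mini
```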


@dataclass
class ModelInfo:
name: str
27 changes: 19 additions & 8 deletions openlrc/openlrc.py
@@ -17,6 +17,7 @@
from openlrc.context import TranslateInfo
from openlrc.defaults import default_asr_options, default_vad_options, default_preprocess_options
from openlrc.logger import logger
from openlrc.models import ModelConfig
from openlrc.opt import SubtitleOptimizer
from openlrc.preprocess import Preprocessor
from openlrc.subtitle import Subtitle, BilingualSubtitle
@@ -36,7 +37,9 @@ class LRCer:
``float16`` or ``float32``. Default: ``float16``
Note: ``default`` will keep the same quantization that was used during model conversion.
device (str): The device to use for computation. Default: ``cuda``
chatbot_model (str): The chatbot model to use, check the available models using list_chatbot_models().
chatbot_model (Union[str, ModelConfig]): The chatbot model to use; check the available models using list_chatbot_models().
The string can be '<model-name>' or '<provider>:<model-name>', e.g. 'gpt-4o-mini' or 'openai:gpt-4o-mini'.
The ModelConfig can be constructed as ModelConfig(provider='<provider>', name='<model-name>', base_url='<url>', api_key='<api-key>', proxy='<proxy>').
Default: ``gpt-4o-mini``
fee_limit (float): The maximum fee you are willing to pay for one translation call. Default: ``0.8``
consumer_thread (int): To prevent exceeding the RPM and TPM limits set by OpenAI, the default is TPM/MAX_TOKEN.
@@ -50,16 +53,17 @@ class LRCer:
glossary (Optional[Union[dict, str, Path]]): A dictionary mapping specific source words to their desired translations.
This is used to enforce custom translations that override the default behavior of the translation model.
Each key-value pair in the dictionary specifies a source word and its corresponding translation. Default: None.
retry_model (Optional[str]): The model to use when retrying the translation. Default: None.
retry_model (Optional[Union[str, ModelConfig]]): The model to use when retrying the translation. Default: None.
is_force_glossary_used (bool): Whether to force the given glossary to be used in context. Default: False
"""

def __init__(self, whisper_model: str = 'large-v3', compute_type: str = 'float16', device: str = 'cuda',
chatbot_model: str = 'gpt-4o-mini', fee_limit: float = 0.8, consumer_thread: int = 4,
chatbot_model: Union[str, ModelConfig] = 'gpt-4o-mini', fee_limit: float = 0.8,
consumer_thread: int = 4,
asr_options: Optional[dict] = None, vad_options: Optional[dict] = None,
preprocess_options: Optional[dict] = None, proxy: Optional[str] = None,
base_url_config: Optional[dict] = None, glossary: Optional[Union[dict, str, Path]] = None,
retry_model: Optional[str] = None, is_force_glossary_used: bool = False):
retry_model: Optional[Union[str, ModelConfig]] = None, is_force_glossary_used: bool = False):
self.chatbot_model = chatbot_model
self.fee_limit = fee_limit
self.api_fee = 0 # Can be updated in different thread, operation should be thread-safe
@@ -424,12 +428,19 @@ def to_json(segments: List[Segment], name, lang):
'segments': []
}

-        for segment in segments:
+        if not segments:
             result['segments'].append({
-                'start': segment.start,
-                'end': segment.end,
-                'text': segment.text
+                'start': 0.0,
+                'end': 5.0,
+                'text': "no speech found"
             })
+        else:
+            for segment in segments:
+                result['segments'].append({
+                    'start': segment.start,
+                    'end': segment.end,
+                    'text': segment.text
+                })

with open(name, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=4)
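
For reference, an empty transcription now writes a placeholder segment instead of an empty list; a sketch:
```python
from openlrc.openlrc import LRCer

# Hypothetical call with no detected speech; assumes to_json is a static method,
# as its signature (no self) suggests.
LRCer.to_json([], name='empty.json', lang='en')
# empty.json then contains:
#   "segments": [{"start": 0.0, "end": 5.0, "text": "no speech found"}]
```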